diff --git a/.flake8 b/.flake8 index 1656330a998..eeec63740e8 100644 --- a/.flake8 +++ b/.flake8 @@ -5,3 +5,4 @@ max-line-length = 119 # E402: module level import not at top of file per-file-ignores = __init__.py:F401,F403,E402 + fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266 diff --git a/.github/workflows/CheckPRTemplate.yml b/.github/workflows/CheckPRTemplate.yml index e5b3dcd3ad9..ba1cd676a03 100644 --- a/.github/workflows/CheckPRTemplate.yml +++ b/.github/workflows/CheckPRTemplate.yml @@ -10,7 +10,8 @@ jobs: check: name: Check PR Template if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BASE_BRANCH: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/Codestyle-Check.yml b/.github/workflows/Codestyle-Check.yml index 6811e3fb38d..0470068d417 100644 --- a/.github/workflows/Codestyle-Check.yml +++ b/.github/workflows/Codestyle-Check.yml @@ -10,7 +10,8 @@ jobs: pre-commit: name: Pre Commit if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BRANCH: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/_accuracy_test.yml b/.github/workflows/_accuracy_test.yml index 4efb008da17..e2c8d40dbfe 100644 --- a/.github/workflows/_accuracy_test.yml +++ b/.github/workflows/_accuracy_test.yml @@ -69,12 +69,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -145,7 +160,10 @@ jobs: docker rm -f ${runner_name} || true fi - docker run --rm --ipc=host --pid=host --net=host \ + docker run --rm --net=host \ + --shm-size=64g \ + --sysctl kernel.msgmax=1048576 \ + --sysctl kernel.msgmnb=268435456 \ --name ${runner_name} \ -v $(pwd):/workspace \ -w /workspace \ @@ -160,8 +178,9 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -204,3 +223,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index b114bad15d4..9f75b2b4b35 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -81,7 +81,14 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' @@ -111,7 +118,11 @@ jobs: exit 1 fi - tar -xf FastDeploy.tar.gz + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -200,8 +211,9 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -254,13 +266,13 @@ jobs: curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }" + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--workers\": 1, \"--max-concurrency\": 5, \"--max-waiting-time\": 1 }" check_service 90 python -m pytest -sv test_max_concurrency.py || TEST_EXIT_CODE=1 curl -X POST http://0.0.0.0:${FLASK_PORT}/switch \ -H "Content-Type: application/json" \ - -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }" + -d "{ \"--model\": \"/MODELDATA/ERNIE-4.5-0.3B-Paddle\", \"--workers\": 1, \"--max-concurrency\": 5000, \"--max-waiting-time\": 1 }" check_service 90 python -m pytest -sv test_max_waiting_time.py || TEST_EXIT_CODE=1 @@ -294,3 +306,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux.yml b/.github/workflows/_build_linux.yml index 172f07cfd73..5865a3cc7fd 100644 --- a/.github/workflows/_build_linux.yml +++ b/.github/workflows/_build_linux.yml @@ -125,6 +125,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -156,7 +157,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -171,6 +173,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -193,7 +196,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -248,3 +251,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_cu129.yml b/.github/workflows/_build_linux_cu129.yml index 6370268c7cb..aabf5bb16a9 100644 --- a/.github/workflows/_build_linux_cu129.yml +++ b/.github/workflows/_build_linux_cu129.yml @@ -112,6 +112,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -143,7 +144,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -158,6 +160,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -180,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu129/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -235,3 +238,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_cu129=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_cu130.yml b/.github/workflows/_build_linux_cu130.yml index 278aff6956b..a294c3557e4 100644 --- a/.github/workflows/_build_linux_cu130.yml +++ b/.github/workflows/_build_linux_cu130.yml @@ -112,6 +112,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -143,7 +144,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache_cu130:/root/.cache" \ @@ -158,6 +160,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -180,7 +183,7 @@ jobs: elif [[ "${PADDLEVERSION}" != "" ]];then python -m pip install paddlepaddle-gpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu130/ else - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu130/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda130-Cudnn913-Trt1013-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu130/ fi pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -235,3 +238,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_cu130=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_fd_router.yml b/.github/workflows/_build_linux_fd_router.yml index b600cc2328e..9e93290d509 100644 --- a/.github/workflows/_build_linux_fd_router.yml +++ b/.github/workflows/_build_linux_fd_router.yml @@ -107,6 +107,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy FD_ROUTER Build shell: bash env: @@ -137,7 +138,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -151,6 +153,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -211,3 +214,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" FD_ROUTER_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/fd-router echo "fd_router_path=${FD_ROUTER_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_linux_rl.yml b/.github/workflows/_build_linux_rl.yml index ede288c805a..fb3a85a5685 100644 --- a/.github/workflows/_build_linux_rl.yml +++ b/.github/workflows/_build_linux_rl.yml @@ -8,7 +8,7 @@ on: description: "Build Images" required: true type: string - default: "iregistry.baidu-int.com/tiangexiao/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-rc2" + default: "iregistry.baidu-int.com/new_rl_infra/base-images:paddlecloud-ubuntu24.04-gcc13.3-cuda12.9-cudnn9.9-bccl1.4.1.4-nccl2.26.5-openmpi4.1.5-FleetY13.0.0-v2.4.0-rc1" FASTDEPLOY_ARCHIVE_URL: description: "URL of the compressed FastDeploy code archive." required: true @@ -52,9 +52,10 @@ on: wheel_path_rl: description: "Output path of the generated wheel" value: ${{ jobs.fd-build-rl.outputs.wheel_path_rl }} + jobs: fd-build-rl: - runs-on: [self-hosted, GPU-Build] + runs-on: [self-hosted, GPU-Build-RL] timeout-minutes: 360 outputs: wheel_path_rl: ${{ steps.set_output.outputs.wheel_path_rl }} @@ -107,6 +108,7 @@ jobs: git config --global user.name "FastDeployCI" git config --global user.email "fastdeploy_ci@example.com" git log -n 3 --oneline + - name: FastDeploy Build shell: bash env: @@ -137,7 +139,8 @@ jobs: PARENT_DIR=$(dirname "$WORKSPACE") echo "PARENT_DIR:$PARENT_DIR" docker run --rm --net=host \ - --cap-add=SYS_PTRACE --privileged --shm-size=64G \ + --cap-add=SYS_PTRACE --shm-size=64G \ + --name ${runner_name} \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache_rl:/root/.cache" \ @@ -151,6 +154,7 @@ jobs: -e "PADDLE_WHL_URL=${PADDLE_WHL_URL}" \ -e "BRANCH_REF=${BRANCH_REF}" \ -e "CCACHE_MAXSIZE=50G" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${gpu_id}\"" ${docker_image} /bin/bash -c ' if [[ -n "${FD_VERSION}" ]]; then export FASTDEPLOY_VERSION=${FD_VERSION} @@ -162,8 +166,7 @@ jobs: cd FastDeploy python -m pip uninstall paddlepaddle-gpu -y || true - wget -q --no-proxy https://paddle-qa.bj.bcebos.com/paddle-pipeline/Develop-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/latest/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl - python -m pip install paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu129/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -202,3 +205,10 @@ jobs: target_path_stripped="${target_path#paddle-github-action/}" WHEEL_PATH=https://paddle-github-action.bj.bcebos.com/${target_path_stripped}/${fd_wheel_name} echo "wheel_path_rl=${WHEEL_PATH}" >> $GITHUB_OUTPUT + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_build_xpu.yml b/.github/workflows/_build_xpu.yml index b9bab8381d0..1222f040812 100644 --- a/.github/workflows/_build_xpu.yml +++ b/.github/workflows/_build_xpu.yml @@ -159,7 +159,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi diff --git a/.github/workflows/_golang_router_test.yml b/.github/workflows/_golang_router_test.yml index 4964f3a3a05..62810d527a0 100644 --- a/.github/workflows/_golang_router_test.yml +++ b/.github/workflows/_golang_router_test.yml @@ -76,12 +76,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -191,12 +206,13 @@ jobs: -e "fd_router_url=${fd_router_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt @@ -211,3 +227,10 @@ jobs: bash scripts/run_golang_router.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_gpu_4cards_case_test.yml b/.github/workflows/_gpu_4cards_case_test.yml index 02a16b8b93b..c4b771c15fd 100644 --- a/.github/workflows/_gpu_4cards_case_test.yml +++ b/.github/workflows/_gpu_4cards_case_test.yml @@ -81,12 +81,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -186,13 +201,14 @@ jobs: -e "fd_wheel_url=${fd_wheel_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt @@ -204,3 +220,10 @@ jobs: export CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/run_gpu_4cards.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_logprob_test_linux.yml b/.github/workflows/_logprob_test_linux.yml index 47486cef243..5ccd0be40fa 100644 --- a/.github/workflows/_logprob_test_linux.yml +++ b/.github/workflows/_logprob_test_linux.yml @@ -78,11 +78,27 @@ jobs: if ls /workspace/* >/dev/null 2>&1; then echo "ERROR: Failed to clean /workspace/* after multiple attempts" ls -ld /workspace/* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls /workspace/* >/dev/null 2>&1; then + echo "ERROR: Force cleanup failed. Exiting..." + exit 1 + else + echo "Force cleanup succeeded." + fi fi ' - wget -q --no-proxy ${paddletest_archive_url} - tar -xf PaddleTest.tar.gz + + wget -q --no-proxy ${paddletest_archive_url} || { + echo "ERROR: Failed to download archive from ${paddletest_archive_url}" + exit 1 + } + + tar --no-same-owner -xf PaddleTest.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf PaddleTest.tar.gz cd PaddleTest git config --global user.name "FastDeployCI" @@ -171,8 +187,9 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -223,3 +240,10 @@ jobs: run: | echo "logprob test failed with exit code ${{ env.LOGPROB_EXIT_CODE }}" exit 8 + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_pre_ce_test.yml b/.github/workflows/_pre_ce_test.yml index 9e313606a36..0669d503d80 100644 --- a/.github/workflows/_pre_ce_test.yml +++ b/.github/workflows/_pre_ce_test.yml @@ -83,12 +83,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -182,10 +197,18 @@ jobs: -e "FD_ZMQ_SEND_RESPONSE_SERVER_PORT=${FD_ZMQ_SEND_RESPONSE_SERVER_PORT}" \ -e "FD_ZMQ_CONTROL_CMD_SERVER_PORTS=${FD_ZMQ_CONTROL_CMD_SERVER_PORTS}" \ -e "fd_wheel_url=${fd_wheel_url}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ python -m pip install ${fd_wheel_url} bash scripts/run_pre_ce.sh ' + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_stable_test.yml b/.github/workflows/_stable_test.yml index dd4ce4e811d..8678490f9d7 100644 --- a/.github/workflows/_stable_test.yml +++ b/.github/workflows/_stable_test.yml @@ -81,12 +81,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -176,8 +191,9 @@ jobs: -v "${CACHE_DIR}/.cache:/root/.cache" \ -v "${CACHE_DIR}/ConfigDir:/root/.config" \ -e TZ="Asia/Shanghai" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus '"device='"${DEVICES}"'"' ${docker_image} /bin/bash -xc ' - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple @@ -221,3 +237,10 @@ jobs: fi echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" exit ${TEST_EXIT_CODE} + + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} diff --git a/.github/workflows/_unit_test_coverage.yml b/.github/workflows/_unit_test_coverage.yml index 75aef3e937a..1cb1ca41213 100644 --- a/.github/workflows/_unit_test_coverage.yml +++ b/.github/workflows/_unit_test_coverage.yml @@ -86,12 +86,27 @@ jobs: if ls "${REPO_NAME}"* >/dev/null 2>&1; then echo "ERROR: Failed to clean ${REPO_NAME}* after multiple attempts" ls -ld "${REPO_NAME}"* - exit 1 + echo "Attempting force cleanup with find..." + find /workspace -mindepth 1 -maxdepth 1 -name "${REPO_NAME}*" -type d -exec chmod -R u+rwx {} \; -exec rm -rf {} + 2>/dev/null || true + if ls "${REPO_NAME}"* >/dev/null 2>&1; then + echo "ERROR: Force cleanup still failed" + exit 1 + else + echo "Force cleanup succeeded" + fi fi ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz + wget -q --no-proxy ${fd_archive_url} || { + echo "ERROR: Failed to download archive from ${fd_archive_url}" + exit 1 + } + + tar --no-same-owner -xf FastDeploy.tar.gz || { + echo "ERROR: Failed to extract archive" + exit 1 + } + rm -rf FastDeploy.tar.gz cd FastDeploy git config --global user.name "FastDeployCI" @@ -178,10 +193,12 @@ jobs: --sysctl kernel.msgmnb=268435456 \ --name ${runner_name} \ --cap-add=SYS_PTRACE --cap-add=IPC_LOCK \ - --shm-size=64G \ + --shm-size=128G \ ${RDMA_DEVICES} \ --device=/dev/infiniband/rdma_cm \ --ulimit memlock=-1:-1 \ + --ulimit nofile=65536:65536 \ + --ulimit nproc=8192:8192 \ -v $(pwd):/workspace -w /workspace \ -v "${CACHE_DIR}/gitconfig:/etc/gitconfig:ro" \ -v "${CACHE_DIR}/.cache:/root/.cache" \ @@ -201,12 +218,13 @@ jobs: -e "fd_wheel_url=${fd_wheel_url}" \ -e "BASE_REF=${BASE_REF}" \ -e "IS_PR=${IS_PR}" \ + -e "no_proxy=localhost,127.0.0.1,0.0.0.0,bcebos.com,.bcebos.com,bj.bcebos.com,su.bcebos.com,paddle-ci.gz.bcebos.com,apiin.im.baidu.com,baidu-int.com,.baidu.com,aliyun.com,gitee.com,pypi.tuna.tsinghua.edu.cn,.tuna.tsinghua.edu.cn" \ --gpus "\"device=${DEVICES}\"" ${docker_image} /bin/bash -c ' git config --global --add safe.directory /workspace/FastDeploy cd FastDeploy git diff origin/${BASE_REF}..HEAD --unified=0 > diff.txt - python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ + python -m pip install https://paddle-qa.bj.bcebos.com/paddle-pipeline/Release-TagBuild-Training-Linux-Gpu-Cuda12.6-Cudnn9.5-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/2b9f8b689bc8988f97a5ede056c8c81bfa0332c2/paddlepaddle_gpu-3.3.1.post20260420+2b9f8b689bc-cp310-cp310-linux_x86_64.whl --extra-index-url https://www.paddlepaddle.org.cn/packages/stable/cu126/ pip config set global.extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple python -m pip install -r scripts/unittest_requirement.txt @@ -376,7 +394,7 @@ jobs: wget -O ${filename} ${diff_cov_result_json_url} || echo "Download cov json file failed, but continuing..." fi if [ -f "${filename}" ];then - echo "Failed test cases:" + echo "GPU Patch Coverage Details:" if command -v jq >/dev/null 2>&1; then jq . "${filename}" else @@ -388,14 +406,25 @@ jobs: echo "coverage passed" exit 0 + - name: Terminate and delete the container + if: always() + run: | + set +e + docker exec -t ${{ runner.name }} /bin/bash -c 'find /workspace -mindepth 1 -delete' + docker rm -f ${{ runner.name }} + diff_coverage_report: needs: run_tests_with_coverage if: always() - runs-on: ubuntu-latest + runs-on: + group: APPROVAL timeout-minutes: 15 env: all_cov_file_url: ${{ needs.run_tests_with_coverage.outputs.all_cov_file_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: diff --git a/.github/workflows/_xpu_4cards_case_test.yml b/.github/workflows/_xpu_4cards_case_test.yml index f3c97f40dc6..0548ea2afcb 100644 --- a/.github/workflows/_xpu_4cards_case_test.yml +++ b/.github/workflows/_xpu_4cards_case_test.yml @@ -178,7 +178,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi echo "安装上游任务编译的fastdeploy-xpu..." python -m pip install ${FASTDEPLOY_WHEEL_URL} @@ -213,6 +213,7 @@ jobs: - name: Upload case logs if: always() + continue-on-error: true uses: actions/upload-artifact@v6 with: name: xpu-4cards-case-logs diff --git a/.github/workflows/_xpu_8cards_case_test.yml b/.github/workflows/_xpu_8cards_case_test.yml index c9ed0fa2314..a0afceab1ad 100644 --- a/.github/workflows/_xpu_8cards_case_test.yml +++ b/.github/workflows/_xpu_8cards_case_test.yml @@ -167,7 +167,7 @@ jobs: python -m pip install paddlepaddle-xpu==${PADDLEVERSION} -i https://www.paddlepaddle.org.cn/packages/stable/xpu-p800/ else python -m pip uninstall paddlepaddle-xpu fastdeploy-xpu -y - python -m pip install --pre paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/ + python -m pip install https://paddle-whl.bj.bcebos.com/nightly/xpu-p800/paddlepaddle-xpu/paddlepaddle_xpu-3.5.0.dev20260507-cp310-cp310-linux_x86_64.whl fi echo "安装上游任务编译的fastdeploy-xpu..." python -m pip install ${FASTDEPLOY_WHEEL_URL} @@ -201,6 +201,7 @@ jobs: - name: Upload case logs if: always() + continue-on-error: true uses: actions/upload-artifact@v6 with: name: xpu-8cards-case-logs diff --git a/.github/workflows/approve.yml b/.github/workflows/approve.yml index 6de30d6f564..39b1844da74 100644 --- a/.github/workflows/approve.yml +++ b/.github/workflows/approve.yml @@ -13,11 +13,15 @@ jobs: Approval: name: Approval if: ${{ github.repository_owner == 'PaddlePaddle' }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: PR_ID: ${{ github.event.pull_request.number }} BRANCH: ${{ github.event.pull_request.base.ref }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Checkout base repo uses: actions/checkout@v6 with: diff --git a/.github/workflows/cancel_ci_iluvatar.yml b/.github/workflows/cancel_ci_iluvatar.yml index 9dba9a7d1e0..1bb5ae247d4 100644 --- a/.github/workflows/cancel_ci_iluvatar.yml +++ b/.github/workflows/cancel_ci_iluvatar.yml @@ -13,8 +13,12 @@ concurrency: jobs: cancel: name: Cancel ILUVATAR-CI for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel ILUVATAR-CI run: | exit 0 diff --git a/.github/workflows/cancel_ci_xpu.yml b/.github/workflows/cancel_ci_xpu.yml index befd59796e9..dab6c9ce79a 100644 --- a/.github/workflows/cancel_ci_xpu.yml +++ b/.github/workflows/cancel_ci_xpu.yml @@ -13,8 +13,12 @@ concurrency: jobs: cancel: name: Cancel CI_XPU for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel CI_XPU run: | exit 0 diff --git a/.github/workflows/cancel_pr_build_and_test.yml b/.github/workflows/cancel_pr_build_and_test.yml index bb488a529ea..0cc0f3d0671 100644 --- a/.github/workflows/cancel_pr_build_and_test.yml +++ b/.github/workflows/cancel_pr_build_and_test.yml @@ -12,8 +12,12 @@ concurrency: jobs: cancel: name: Cancel PR Build and Test for ${{ github.event.pull_request.number }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Cancel PR Build and Test run: | exit 0 diff --git a/.github/workflows/ce_job.yml b/.github/workflows/ce_job.yml index 5b20eccdf2e..c9ce4400ae3 100644 --- a/.github/workflows/ce_job.yml +++ b/.github/workflows/ce_job.yml @@ -14,7 +14,8 @@ concurrency: jobs: ce_job_pre_check: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: COMPILE_BRANCH: ${{ vars.COMPILE_BRANCH }} CE_COMPILE_SELECTION: ${{ vars.CE_COMPILE_SELECTION }} @@ -26,6 +27,9 @@ jobs: sm8090_match: ${{ steps.set_output.outputs.sm8090_match }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Set Version id: set_output env: @@ -78,9 +82,13 @@ jobs: done print_ce_job_pre_check_outputs: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: ce_job_pre_check steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print outputs as JSON run: | echo '${{ toJSON(needs.ce_job_pre_check.outputs) }}' @@ -89,12 +97,16 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: ce_job_pre_check if: ${{ needs.ce_job_pre_check.outputs.branch_match == 'true' }} outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -154,8 +166,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print repo_archive_url path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" @@ -186,7 +202,7 @@ jobs: COMPILE_ARCH: "80,90" WITH_NIGHTLY_BUILD: OFF FD_VERSION: 0.0.0 - PADDLE_WHL_URL: https://paddle-qa.bj.bcebos.com/paddle-pipeline/Paddle-RL-Compile/develop/latest/paddlepaddle_gpu-3.3.0.dev-cp310-cp310-linux_x86_64.whl + PADDLE_WHL_URL: https://paddle-qa.bj.bcebos.com/paddle-pipeline/Develop-TagBuild-Training-Linux-Gpu-Cuda12.9-Cudnn9.9-Trt10.5-Mkl-Avx-Gcc11-SelfBuiltPypiUse/latest/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl build_sm8689: name: BUILD_SM8689 @@ -207,13 +223,17 @@ jobs: environment: CodeSync name: CE_UPLOAD needs: build_sm8090 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090.outputs.wheel_path }} COMPILE_ARCH: "80,90" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -257,13 +277,17 @@ jobs: environment: CodeSync name: CE_UPLOAD_RL needs: build_sm8090_rl - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8090_rl.outputs.wheel_path_rl }} COMPILE_ARCH: "80,90" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -303,13 +327,17 @@ jobs: environment: CodeSync name: CE_UPLOAD needs: build_sm8689 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_sm8689.outputs.wheel_path }} COMPILE_ARCH: "86,89" steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/check-bypass.yml b/.github/workflows/check-bypass.yml index c9256e7a6cf..a799bbe3a41 100644 --- a/.github/workflows/check-bypass.yml +++ b/.github/workflows/check-bypass.yml @@ -18,7 +18,8 @@ on: jobs: check-bypass: name: Check bypass - runs-on: ubuntu-latest + runs-on: + group: APPROVAL permissions: contents: read env: @@ -64,7 +65,9 @@ jobs: exit 0 fi - files=$(gh pr view ${{ github.event.pull_request.number }} --repo ${{ github.repository }} --json files --jq '.files[].path') + files=$(curl -s -H "Authorization: token $GITHUB_TOKEN" \ + "https://api.github.com/repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/files?per_page=100" \ + | jq -r '.[].filename') echo "$files" can_skip_docs=true diff --git a/.github/workflows/cherry-pick.yml b/.github/workflows/cherry-pick.yml index c6e1bad992e..407acbea687 100644 --- a/.github/workflows/cherry-pick.yml +++ b/.github/workflows/cherry-pick.yml @@ -22,8 +22,12 @@ jobs: github.event.action == 'labeled' || contains(join(github.event.pull_request.labels.*.name, ' '), 'cherry-pick') ) - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Checkout uses: actions/checkout@v6 with: diff --git a/.github/workflows/ci_image_update.yml b/.github/workflows/ci_image_update.yml index 762cad91023..ae6b1b5d0e8 100644 --- a/.github/workflows/ci_image_update.yml +++ b/.github/workflows/ci_image_update.yml @@ -16,10 +16,14 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -64,8 +68,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" diff --git a/.github/workflows/ci_metax.yml b/.github/workflows/ci_metax.yml deleted file mode 100644 index 5584147eb8c..00000000000 --- a/.github/workflows/ci_metax.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: CI_METAX - -on: - pull_request_target: - types: - - opened - - synchronize - branches: - - develop - - release/** - -permissions: - contents: read - -concurrency: - group: jenkins-pr-${{ github.event.pull_request.number }} - cancel-in-progress: true - -jobs: - trigger-jenkins: - name: Trigger Jenkins for PR - runs-on: ubuntu-latest - environment: Metax_ci - - steps: - - name: Trigger Jenkins job - timeout-minutes: 120 - uses: MetaX-MACA/simple-jenkins-githubaction@v1.1 - with: - job_name: paddle_fastdeploy_metax_smoketest - username: fastdeploy_builder - api_token: ${{ secrets.METAX_JENKINS_API_TOKEN }} - pr_number: ${{ github.event.pull_request.number }} - project_branch: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 6c06ed0a6aa..17a64cf1d88 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -9,8 +9,12 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: diff --git a/.github/workflows/pr_build_and_test.yml b/.github/workflows/pr_build_and_test.yml index 9ffcd75ee5c..bbad1ee939c 100644 --- a/.github/workflows/pr_build_and_test.yml +++ b/.github/workflows/pr_build_and_test.yml @@ -32,8 +32,12 @@ jobs: resultshow: name: Use Build Output needs: build - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The built wheel is located at: ${{ needs.build.outputs.wheel_path }}" diff --git a/.github/workflows/publish_job.yml b/.github/workflows/publish_job.yml index 9207d58a497..e5b98392665 100644 --- a/.github/workflows/publish_job.yml +++ b/.github/workflows/publish_job.yml @@ -19,7 +19,8 @@ concurrency: jobs: publish_pre_check: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL if: | github.event.repository.fork == false && ( @@ -40,6 +41,9 @@ jobs: compile_use_paddle_whl_url: ${{ steps.set_output.outputs.compile_use_paddle_whl_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Get tag version if: github.ref_type == 'tag' run: | @@ -108,9 +112,13 @@ jobs: echo "with_nightly_build=${with_nightly_build:-OFF}" >> $GITHUB_OUTPUT print_publish_pre_check_outputs: - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: publish_pre_check steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print outputs as JSON run: | echo '${{ toJSON(needs.publish_pre_check.outputs) }}' @@ -118,12 +126,16 @@ jobs: clone: environment: CodeSync name: FD-Clone-Linux - runs-on: ubuntu-latest + runs-on: + group: APPROVAL needs: publish_pre_check if: ${{ needs.publish_pre_check.outputs.compile_continue == 'true' }} outputs: repo_archive_url: ${{ steps.set_output.outputs.repo_archive_url }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Clone FastDeploy uses: actions/checkout@v6 with: @@ -168,8 +180,12 @@ jobs: resultshow: name: Show Code Archive Output needs: clone - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Print wheel path run: | echo "The code archive is located at: ${{ needs.clone.outputs.repo_archive_url }}" @@ -235,12 +251,16 @@ jobs: environment: CodeSync name: CE_UPLOAD_FD_ROUTER needs: build_fd_router - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FD_ROUTER_URL: ${{ needs.build_fd_router.outputs.fd_router_path }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -291,12 +311,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu126 needs: build_cu126 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu126.outputs.wheel_path }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -323,12 +347,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu129 needs: build_cu129 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu129.outputs.wheel_path_cu129 }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' @@ -355,12 +383,16 @@ jobs: environment: PaddleSourceUpload name: PADDLE_PYPI_UPLOAD_cu130 needs: build_cu130 - runs-on: ubuntu-latest + runs-on: + group: APPROVAL env: AK: ${{ secrets.BOS_AK }} SK: ${{ secrets.BOS_SK }} FASTDEPLOY_WHEEL_URL: ${{ needs.build_cu130.outputs.wheel_path_cu130 }} steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/remove-skip-ci-labels.yml b/.github/workflows/remove-skip-ci-labels.yml index 978f70ea240..aace7ae5af6 100644 --- a/.github/workflows/remove-skip-ci-labels.yml +++ b/.github/workflows/remove-skip-ci-labels.yml @@ -10,8 +10,12 @@ permissions: jobs: remove-skip-ci-labels: name: Remove skip-ci labels on new commits - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: + - name: Cleanup + run: | + rm -rf * .[^.]* - name: Get PR labels id: get-labels uses: actions/github-script@v8 diff --git a/.github/workflows/rerun.yml b/.github/workflows/rerun.yml index bbc96edd37e..6527ccf0679 100644 --- a/.github/workflows/rerun.yml +++ b/.github/workflows/rerun.yml @@ -7,7 +7,8 @@ on: jobs: re-run: if: ${{ github.event.issue.pull_request && contains(github.event.comment.body, '/re-run') && github.event.comment.user.login == github.event.issue.user.login }} - runs-on: ubuntu-latest + runs-on: + group: APPROVAL steps: - name: Cleanup run: | diff --git a/README_CN.md b/README_CN.md index 4d110b8d830..33642687062 100644 --- a/README_CN.md +++ b/README_CN.md @@ -27,7 +27,9 @@ ## 最新活动 -**[2026-01] FastDeploy v2.4 全新发布!** 新增 DeepSeek V3 与 Qwen3-MoE 模型的 PD 分离部署,增强MTP 投机解码能力,全面优化多硬件平台上的 MoE 推理与多模态前缀缓存性能,升级全部内容参阅 [v2.4 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0)。 +**[2026-03] FastDeploy v2.5 全新发布!** 新增Qwen3-VL与Qwen3-VL MoE模型部署支持,新增W4AFP8量化方法,增强强化学习训练支持能力,包含170+项Bug修复与性能优化,升级全部内容参阅 [v2.5 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.5.0)。 + +**[2026-01] FastDeploy v2.4**: 新增 DeepSeek V3 与 Qwen3-MoE 模型的 PD 分离部署,增强MTP 投机解码能力,全面优化多硬件平台上的 MoE 推理与多模态前缀缓存性能,升级全部内容参阅 [v2.4 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0)。 **[2025-11] FastDeploy v2.3**: 新增[ERNIE-4.5-VL-28B-A3B-Thinking](docs/zh/get_started/ernie-4.5-vl-thinking.md)与[PaddleOCR-VL-0.9B](docs/zh/best_practices/PaddleOCR-VL-0.9B.md)两大重磅模型在多硬件平台上的部署支持,进一步优化全方位推理性能,以及带来更多部署功能和易用性的提升,升级全部内容参阅[v2.3 ReleaseNote](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0)。 diff --git a/README_EN.md b/README_EN.md index 4d918455d5f..72c8cf1a1ac 100644 --- a/README_EN.md +++ b/README_EN.md @@ -27,7 +27,9 @@ English | [简体中文](README_CN.md) ## News -[2026-01] FastDeploy v2.4 is released! Featuring PD-separated deployment for DeepSeek V3 and Qwen3-MoE, enhanced MTP speculative decoding, and comprehensive performance boosts for MoE inference and multi-modal Prefix Caching across various hardware backends. See the full v2.4 ReleaseNote for more details. +**[2026-03] FastDeploy v2.5 is released!** It adds deployment support for Qwen3-VL and Qwen3-VL MoE models, introduces the W4AFP8 quantization method, enhances reinforcement learning training capabilities, and includes 170+ bug fixes and performance optimizations. For all the upgrade details, refer to the [v2.5 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.5.0). + +**[2026-01] FastDeploy v2.4**: Featuring PD-separated deployment for DeepSeek V3 and Qwen3-MoE, enhanced MTP speculative decoding, and comprehensive performance boosts for MoE inference and multi-modal Prefix Caching across various hardware backends. For all the upgrade details, refer to the [v2.4 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.4.0). **[2025-11] FastDeploy v2.3**: It adds deployment support for two major models, [ERNIE-4.5-VL-28B-A3B-Thinking](docs/get_started/ernie-4.5-vl-thinking.md) and [PaddleOCR-VL-0.9B](docs/best_practices/PaddleOCR-VL-0.9B.md), across multiple hardware platforms. It further optimizes comprehensive inference performance and brings more deployment features and usability enhancements. For all the upgrade details, refer to the [v2.3 Release Note](https://github.com/PaddlePaddle/FastDeploy/releases/tag/v2.3.0). diff --git a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu index 963ccfa23d9..e25816fcbb3 100644 --- a/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu @@ -146,10 +146,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv, rope_3d); } else { if (rotary_dim < dim_head) { - auto* kernelFn = - append_decode_cache_T_neox_partial_rope_kernel; + auto* kernelFn = append_decode_cache_T_neox_partial_rope_kernel< + T, + PackSize, + false>; // GLM use EnforceFmulRN=false launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, diff --git a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh index 0cdea537327..60d5d34bf48 100644 --- a/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh @@ -2543,10 +2543,10 @@ void gqa_rotary_qk_variable( } const int pack_num_new = elem_nums / PackSize; GetNumBlocks<128>(pack_num_new, &grid_size); - auto *kernelFn = - GQANeoxVariableLengthPartialRotaryKernel; + auto *kernelFn = GQANeoxVariableLengthPartialRotaryKernel< + T, + PackSize, + false>; // GLM use EnforceFmulRN=false launchWithPdlWhenEnabled(kernelFn, grid_size, blocksize, diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index f94e8493f7f..d61aa3c2313 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -296,7 +296,7 @@ void GetBlockShapeAndSplitKVBlock( if (!phi::backends::gpu::IsCUDAGraphCapturing()) #endif max_len_tensor_cpu.copy_( - max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + max_len_tensor_gpu, max_len_tensor_cpu.place(), true); auto max_len_cpu_ptr = max_len_tensor_cpu.data(); int max_len_this_time = max_len_cpu_ptr[0]; @@ -378,7 +378,7 @@ void GetBlockShapeAndSplitKVBlock( if (!phi::backends::gpu::IsCUDAGraphCapturing()) #endif decoder_num_blocks_cpu.copy_( - decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false); + decoder_num_blocks_device, decoder_num_blocks_cpu.place(), true); } } // mla_backend not need run the following code. @@ -409,7 +409,7 @@ void GetBlockShapeAndSplitKVBlock( block_size); kv_num_blocks_x_cpu.copy_( - kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false); + kv_num_blocks_x, kv_num_blocks_x_cpu.place(), true); // Clear buffer const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q); @@ -433,7 +433,7 @@ void GetBlockShapeAndSplitKVBlock( encoder_block_shape_q, group_size); encoder_num_blocks_x_cpu.copy_( - encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false); + encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), true); } } diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index e4d0554fea6..c86ec27dca8 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu @@ -387,30 +387,32 @@ void gqa_neox_partial_rotary_qk_split_variable( const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2; - launchWithPdlWhenEnabled( - GQAVariableLengthNeoxPartialRotarySplitKernel, - grid_size, - block_size, - 0, - stream, - qkv_input, - cos_emb, - sin_emb, - batch_id_per_token, - cu_seqlens_q, - seq_lens_encoder, - seq_lens_decoder, - cu_seqlens_k, - qkv_out, - q, - k, - v, - elem_nums, - num_heads, - kv_num_heads, - max_model_len, - head_dim, - rotary_dim); + launchWithPdlWhenEnabled(GQAVariableLengthNeoxPartialRotarySplitKernel< + T, + PackSize, + false>, // GLM use EnforceFmulRN=false + grid_size, + block_size, + 0, + stream, + qkv_input, + cos_emb, + sin_emb, + batch_id_per_token, + cu_seqlens_q, + seq_lens_encoder, + seq_lens_decoder, + cu_seqlens_k, + qkv_out, + q, + k, + v, + elem_nums, + num_heads, + kv_num_heads, + max_model_len, + head_dim, + rotary_dim); } template PreCacheLenConcat( bsz, block_size); paddle::Tensor pre_cache_num_blocks_cpu = - pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false); + pre_cache_num_blocks.copy_to(paddle::CPUPlace(), true); paddle::Tensor kv_token_num_cpu = - kv_token_num.copy_to(paddle::CPUPlace(), false); + kv_token_num.copy_to(paddle::CPUPlace(), true); return { cu_seqlens_k, diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu index e87289a74ec..4ee00f12e07 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu @@ -130,10 +130,11 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv, GetNumBlocks(pack_num, &grid_size); if (use_neox_style) { if (rotary_dim < dim_head) { - append_speculate_cache_neox_partial_rope_kernel + append_speculate_cache_neox_partial_rope_kernel< + T, + PackSize, + QKV_TYPE, + false> // GLM use EnforceFmulRN=false <<>>( qkv, // [token_num, num_heads + 2 * gqa_group_size, head_size] key_cache, diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 40898434bf1..0a65b6de6d3 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -189,6 +189,84 @@ std::vector AppendAttentionWithOutput( const int sliding_window, const int sink_size); +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder); + +std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window); + +void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -537,13 +615,21 @@ std::vector TextImageGatherScatter( const bool is_scatter); std::vector count_tokens_per_expert_func( - const paddle::Tensor& topk_ids, int64_t num_experts); -void GetPositionIdsAndMaskEncoderBatch( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids, - const paddle::Tensor& mask_encoder_batch); + const paddle::Tensor& topk_ids, + int64_t num_experts, + bool compute_padded_cumsum = false); +void GetPositionIds(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids); +void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& position_ids, + const paddle::Tensor& slot_mapping, + const int block_size); std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& kv_nope, @@ -691,6 +777,19 @@ std::vector NoauxTc(paddle::Tensor& scores, bool renormalize, float routed_scaling_factor); +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor); + +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type); + std::vector NoauxTcRedundant( paddle::Tensor& scores, paddle::Tensor& scores_with_bias, @@ -781,6 +880,11 @@ std::vector BuildSamplingParams( const int64_t token_num_output_cpu, const int64_t increment_value); +std::vector BuildSamplingParamLogProb( + const paddle::Tensor& input_params, + const paddle::Tensor& token_num_per_batch, + int64_t token_num_output_cpu); + void SpecTokenPenaltyMultiScores( const paddle::Tensor& token_ids_all, const paddle::Tensor& prompt_lens, @@ -1140,13 +1244,16 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, const paddle::Tensor& seq_lens_this_time, const paddle::Tensor& seq_lens_encoder); -void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, - const paddle::Tensor& logits, - const paddle::Tensor& cu_batch_token_offset, - const paddle::Tensor& ori_cu_batch_token_offset, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& accept_num); +void SpeculateGetAcceptTokensAndLogits( + const paddle::Tensor& token_ids, + const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num, + const paddle::Tensor& accept_tokens); std::vector UpdateAttnMaskOffsets( const paddle::Tensor& ids_remove_padding, @@ -1631,9 +1738,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("is_zp_float")); #endif - m.def("get_position_ids_and_mask_encoder_batch", - &GetPositionIdsAndMaskEncoderBatch, - "get_position_ids_and_mask_encoder_batch function"); + m.def("get_position_ids", &GetPositionIds, "get_position_ids function"); + m.def("get_position_ids_and_slot_mapping", + &GetPositionIdsAndSlotMapping, + "get_position_ids_and_slot_mapping function"); /** * cutlass_scaled_mm.cu @@ -1694,6 +1802,15 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("noaux_tc", &NoauxTc, "noaux_tc for Deepseekv3 MoE compute"); + m.def("grouped_topk", &grouped_topk, "fused grouped topk for MoE routing"); + + m.def("fused_cast_sigmoid_bias", + &FusedCastSigmoidBias, + "Fused cast+sigmoid+bias for MoE gating scores", + py::arg("input"), + py::arg("bias"), + py::arg("cast_type") = std::string("float32")); + m.def("noaux_tc_redundant", &NoauxTcRedundant, "noaux_tc_redundant for MoE compute"); @@ -1769,6 +1886,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &BuildSamplingParams, "build_sampling_params function"); + m.def("build_sampling_params_logprob", + &BuildSamplingParamLogProb, + "build_sampling_params_logprob function"); + m.def("speculate_get_token_penalty_multi_scores", &SpecTokenPenaltyMultiScores, "speculate_get_token_penalty_multi_scores function"); @@ -1870,9 +1991,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) { &SpeculateInsertFirstToken, "speculate_insert_first_token function"); - m.def("speculate_get_target_logits", - &SpeculateGetTargetLogits, - "speculate_get_target_logits function"); + m.def("speculate_get_accept_tokens_and_logits", + &SpeculateGetAcceptTokensAndLogits, + "speculate_get_accept_tokens_and_logits function"); #endif m.def("update_attn_mask_offsets", @@ -1930,4 +2051,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("per_token_group_fp8_quant", &PerTokenGroupQuantFp8, "per_token_group_quant_fp8"); + + /** + * decoder_write_cache_with_rope.cu + * decoder_write_cache_with_rope + */ + m.def("decoder_write_cache_with_rope", + &DecoderWriteCacheWithRoPE, + "decoder write cache with RoPE function"); + + /** + * decode_unified_attention.cu + * decode_unified_attention + */ + m.def("decode_unified_attention", + &DecodeUnifiedAttention, + "decoder append attention function"); + + /** + * config_for_attention.cu + * config_for_attention + */ + m.def("config_for_attention", + &ConfigForAttention, + "config for attention function"); } diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 9c5e7bfc47b..7e93f169028 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -635,7 +635,7 @@ struct MoeFCGemm { static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1010) +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1100) static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, shared_storage); @@ -1060,7 +1060,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 800) && (__CUDA_ARCH__ < 1010) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 1100) KernelRunner::run_kernel( params, shared_storage); #else diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h index db5af4f4938..68b5b054476 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_gemm_kernels_template.h @@ -709,7 +709,7 @@ void MoeGemmRunner::dispatch_to_arch( dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm70); } else if (sm_ >= 75 && sm_ < 80) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm75); - } else if (sm_ >= 80 && sm_ < 101) { + } else if (sm_ >= 80 && sm_ < 104) { dispatch_moe_gemm_to_cutlass_macro(cutlass::arch::Sm80); } else { throw std::runtime_error( diff --git a/custom_ops/gpu_ops/decode_unified_attention.cu b/custom_ops/gpu_ops/decode_unified_attention.cu new file mode 100644 index 00000000000..257134d1e95 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention.cu @@ -0,0 +1,428 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decode_unified_attention/decode_unified_attention_c8_impl.cuh" +#include "decode_unified_attention/decode_unified_attention_c16_impl.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_num = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; + + auto stream = qkv.stream(); + bool is_fp8 = + cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8"; + bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8"; + bool is_c16 = cache_quant_type == "none"; + + if (max_just_dec_len_this_time > 0) { + if (is_c16) { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, max_tokens_per_batch, Q_TILE_SIZE, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})}) + } else { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, + max_tokens_per_batch, + Q_TILE_SIZE, + {DISPATCH_DyCfp8( + is_dynamic_cfp8, + IsDynamicC8, + {DISPATCH_IS_FP8(is_fp8, IsFP8, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})})})}) + } + } + return {fmha_out}; +} + +std::vector> DecodeUnifiedAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& block_indices_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& sinks_shape, + const std::vector& fmha_out_shape, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_shape}; +} + +std::vector DecodeUnifiedAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& sinks_dtype, + const paddle::DataType& fmha_out_dtype, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_dtype}; +} + +PD_BUILD_STATIC_OP(decode_unified_attention) + .Inputs({"qkv", + "key_cache", + "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "block_indices", + "num_blocks", + "chunk_size", + "set_max_lengths", + paddle::Optional("attn_mask"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("mask_offset"), + paddle::Optional("sinks"), + "fmha_out"}) + .Outputs({"fmha_out_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) + .Attrs({ + "cache_quant_type: std::string", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "max_tokens_per_batch: int", + "causal: bool", + "sliding_window: int", + }) + .SetKernelFn(PD_KERNEL(DecodeUnifiedAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(DecodeUnifiedAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecodeUnifiedAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh new file mode 100644 index 00000000000..ee74570e5d8 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh @@ -0,0 +1,1231 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +template +__device__ __forceinline__ void load_block_table_per_chunk( + const int32_t* block_table_chunk_start, + int32_t* block_table_smem, + uint32_t chunk_start, + uint32_t chunk_end, + uint32_t tid, + uint32_t wid) { + uint32_t len = chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE; + for (uint32_t i = 0; i < div_up(len, 128); i++) { + uint32_t offset = wid * kWarpSize + tid + i * 128; + if (offset < len) { + block_table_smem[offset] = block_table_chunk_start[offset]; + } + } +} + +// load q from global memory to shared memory +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; +} + +template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + // 1 warp 32 tokens + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } +} + +template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8]; + } +} + +template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } +} + +template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T* cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; +} + +template +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0) { + const uint32_t tx = threadIdx.x; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + bool out_of_boundary; + if (mask_offset) { + const int2 mo = reinterpret_cast(mask_offset)[q_idx]; + out_of_boundary = + q_idx < qo_len ? (kv_idx >= mo.y || kv_idx < mo.x) : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t j_id = j * 2; + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_tmp = s_frag[fx][fz] + j_id; + float m_local = max(max(s_frag_tmp[0], s_frag_tmp[1]), + max(s_frag_tmp[4], s_frag_tmp[5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; + float2 fp2_scale = make_float2(o_scale, o_scale); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_frag_ptr = reinterpret_cast(o_frag[fx][fy] + j_id); + o_frag_ptr[0] = fast_float2_mul(o_frag_ptr[0], fp2_scale); + o_frag_ptr[2] = fast_float2_mul(o_frag_ptr[2], fp2_scale); + } + float tmp_m = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_ptr = s_frag[fx][fz] + j_id; + s_frag_ptr[0] = __expf(s_frag_ptr[0] - tmp_m); + s_frag_ptr[1] = __expf(s_frag_ptr[1] - tmp_m); + s_frag_ptr[4] = __expf(s_frag_ptr[4] - tmp_m); + s_frag_ptr[5] = __expf(s_frag_ptr[5] - tmp_m); + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T* cache_v_scale) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } + } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + float (*d)[2], + const uint32_t wid, + const uint32_t tid, + const bool normalize = false) { + // Padded row stride (33 instead of 32) to avoid cross-row bank conflicts. + constexpr uint32_t kRowStride = 33; + // o_smem row stride in floats: kRowStride * 8 = 264 + constexpr uint32_t kORowStride = kRowStride * 8; + // md_smem base offset: after all o_smem data + // NUM_WARPS(4) * num_frags_x * num_frags_y * kORowStride floats + constexpr uint32_t kOMemFloats = 4 * num_frags_x * num_frags_y * kORowStride; + float2* smem_md = reinterpret_cast(md_smem + kOMemFloats); + + // Phase 1: Write m/d to smem only (2KB, no o data yet) +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * kRowStride + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } + __syncthreads(); + + // Phase 2: Compute global m/d and scale own o_frag in registers + float scale_j[2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = + smem_md[((i * num_frags_x + fx) * 2 + j) * kRowStride + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = fmaf(d_prev, expf(m_prev - m_new), md.y * expf(md.x - m_new)); + } + float own_scale = expf(m[fx][j] - m_new); + m[fx][j] = m_new; + d[fx][j] = d_new; + float d_rcp = normalize ? (1.f / d_new) : 1.f; + scale_j[j] = own_scale * d_rcp; + } + // Apply scale to o_frag using WGMMA fragment layout: + // regs 0,1→j=0, 2,3→j=1, 4,5→j=0, 6,7→j=1 + // i.e., float2 index k → j = k % 2 +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t k = 0; k < 4; ++k) { + float s = scale_j[k % 2]; + o_frag[fx][fy][2 * k + 0] *= s; + o_frag[fx][fy][2 * k + 1] *= s; + } + } + } + + // Phase 3: Write pre-scaled o_frag to smem with padded stride +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_smem_start = + (float2*)(md_smem + + ((wid * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + o_smem_start[i * kRowStride] = ((float2*)(&o_frag[fx][fy][0]))[i]; + } + } + } + __syncthreads(); + + // Phase 4: Accumulate all warps' scaled o_frag +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_new_fp2 = reinterpret_cast(&o_frag[fx][fy][0]); +#pragma unroll + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + o_new_fp2[o_id] = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi_fp2; + float2* o_smem_start = + (float2*)(md_smem + + ((i * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + oi_fp2[reg_id] = o_smem_start[reg_id * kRowStride]; + } +#pragma unroll + for (uint32_t reg_fp2_id = 0; reg_fp2_id < 4; ++reg_fp2_id) { + o_new_fp2[reg_fp2_id].x += oi_fp2[reg_fp2_id].x; + o_new_fp2[reg_fp2_id].y += oi_fp2[reg_fp2_id].y; + } + } + } + } +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + __syncthreads(); + + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); + + const uint32_t tx_offset = tx / 8; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; + const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + if (n_offset < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } +} + +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +// C16 (fp16/bf16 KV cache) helper functions + +template +__device__ __forceinline__ void produce_kv_blockwise(smem_t smem, + uint32_t* smem_offset, + T** gptr, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; + *gptr += + num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + uint32_t a_frag[num_frags_x][4], b_frag[4]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + *q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * num_vecs_per_head; + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } + } + } + *k_smem_offset_r = + k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * num_vecs_per_head; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * 2; +} + +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t b_frag[4]; + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); + } + *v_smem_offset_r = + v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy); + } + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) - + 2 * num_frags_y; + } + *v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head; +} + +template +__global__ void merge_chunks_kernel( + const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads, + // head_dim] + const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads] + const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + const int* __restrict__ chunk_size_ptr, + T* __restrict__ out, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int head_dim, + const int token_num, + const int max_tokens_per_batch = 5) { + const int vid = threadIdx.x, ty = threadIdx.y; + const int hid = blockIdx.y; + // After intra-warp reduction, only bdy/2 results need smem storage + __shared__ T smem[(bdy / 2) * HEAD_DIM]; + __shared__ float md_smem[(bdy / 2) * 2]; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + // Phase 1: Fast path — all ty participate independently (no smem, no + // syncthreads) Each ty handles a different qid with stride gridDim.x * bdy + using LoadT = AlignedVector; + for (int qid = blockIdx.x + ty * gridDim.x; qid < token_num; + qid += gridDim.x * bdy) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; + if (seq_lens_encoder[bid] > 0) continue; // skip prefill batches + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq != 1) continue; // handled in Phase 2 + + LoadT load_vec; + uint32_t offset = + ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * num_heads + + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + Store( + load_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + + // Phase 2: Slow path — merge multi-chunk results + // Optimization: use warp-shuffle reduction within each warp, then cross-warp + // via smem. This eliminates the large smem[bdy * HEAD_DIM] buffer and reduces + // syncthreads from 2 per qid to 1 per qid. + // Block layout: (blockx=16, bdy=8) => 4 warps, each warp has 2 ty values + // Warp 0: ty=0,1 Warp 1: ty=2,3 Warp 2: ty=4,5 Warp 3: ty=6,7 + // Lane layout within warp: lanes 0-15 = (ty_low, vid), lanes 16-31 = + // (ty_high, vid) + const int lane_id = (ty * blockDim.x + vid) % 32; + + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; // uniform skip — no syncthreads needed + if (seq_lens_encoder[bid] > 0) continue; + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq == 1) continue; // handled in Phase 1 + + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } + + // Step 1: Each ty iterates over its chunk subset and does local online + // softmax merge +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks + i) * + num_heads + + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * + num_heads + + i * num_heads + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + + // Step 2: Intra-warp reduction via warp shuffle + // Each warp has 2 ty values: ty_low at lanes 0-15, ty_high at lanes 16-31 + // Merge ty_high into ty_low using shuffle + const int partner_lane = lane_id ^ 16; // flip bit 4 to swap low/high ty + const float m_partner = __shfl_sync(0xffffffff, m, partner_lane); + const float d_partner = __shfl_sync(0xffffffff, d, partner_lane); + // Pack adjacent 16-bit pairs into 32-bit for efficient shuffle. + // AlignedVector alignment >= 4 bytes, so uint32 reinterpret is safe + // — no OOB read, no type confusion. This halves shuffle count vs + // per-element memcpy for bf16/fp16. + constexpr int PACKED_SIZE = vec_size * sizeof(T) / sizeof(unsigned); + const unsigned* packed_res = reinterpret_cast(&res_vec); + unsigned packed_partner[PACKED_SIZE]; +#pragma unroll + for (int j = 0; j < PACKED_SIZE; j++) { + packed_partner[j] = __shfl_sync(0xffffffff, packed_res[j], partner_lane); + } + LoadT partner_vec; + memcpy(&partner_vec, packed_partner, sizeof(partner_vec)); + + // Merge partner into self (only the "low ty" keeps the result) + float m_new = max(m, m_partner); + const float scale1 = __expf(m - m_new); + const float scale2 = __expf(m_partner - m_new); + float d_new = d * scale1 + d_partner * scale2; + if ((ty & 1) == 0) { // low ty keeps merged result + m = m_new; + d = d_new; + const T scale1_T = static_cast(scale1); + const T scale2_T = static_cast(scale2); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + partner_vec[j] * scale2_T; + } + } + + // Cross-warp: only even ty (0,2,4,6) write to smem + if ((ty & 1) == 0) { + Store(res_vec, &smem[(ty / 2) * head_dim + vid * vec_size]); + md_smem[ty] = m; + md_smem[ty + 1] = d; + } + __syncthreads(); + + if (ty == 0) { + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy / 2; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); + } + Store( + out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + __syncthreads(); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu new file mode 100644 index 00000000000..7033cbd10bf --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu @@ -0,0 +1,409 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "cute/tensor.hpp" +#include "helper.h" +#include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#endif +#include "utils.cuh" + +template +__global__ void GetMaxLenKernel(const int* seq_lens_decoder, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int* max_lens, + const int batch_size) { + const int tid = threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int max_len_this_time_this_thread = 0; + int max_len_encoder_this_thread = 0; + int max_len_decoder_this_thread = 0; + int max_len_this_thread = 0; + int max_just_dec_len_this_thread = 0; + int max_len_kv_this_thread = 0; + for (int i = tid; i < batch_size; i += blockDim.x) { + const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; + max_len_this_time_this_thread = + max(seq_len_this_time, max_len_this_time_this_thread); + max_len_encoder_this_thread = + max(seq_lens_encoder[i], max_len_encoder_this_thread); + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder; + max_len_this_thread = + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); + max_just_dec_len_this_thread = + max(max_just_dec_len_this_thread, max_just_dec_len_now); + + if (seq_len_decoder == 0) continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); + } + int total_max_len_this_time = + BlockReduce(temp_storage) + .Reduce(max_len_this_time_this_thread, MaxOp()); + int total_max_len_encoder = + BlockReduce(temp_storage) + .Reduce(max_len_encoder_this_thread, MaxOp()); + int total_max_len_decoder = + BlockReduce(temp_storage) + .Reduce(max_len_decoder_this_thread, MaxOp()); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + int total_just_dec = BlockReduce(temp_storage) + .Reduce(max_just_dec_len_this_thread, MaxOp()); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); + if (tid == 0) { + max_lens[0] = total_max_len_this_time; + max_lens[1] = total_max_len_encoder; + max_lens[2] = total_max_len_decoder; + max_lens[3] = total; + max_lens[4] = total_just_dec; + max_lens[5] = total_max_len_kv; + } +} + +template +__global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + int4* __restrict__ block_indices, + int* __restrict__ num_blocks, + int* __restrict__ chunk_size, + const int bsz, + const int group_size, + const int kv_num_heads, + const int q_tile_size, + const int max_tokens_per_batch, + const int config_gridx) { + const int tid = threadIdx.x, wid = threadIdx.y; + const uint32_t warp_size = blockDim.x; + __shared__ int num_block_all_shared[block_size]; + __shared__ int chunk_size_res[1]; + + const int lane_id = tid + wid * warp_size; + + // Merged Step 1+2: single bsz loop computing both Scheme E metrics and + // split-KV block counts per lane. Avoids redundant seq_lens reads and + // shared intermediate values (token_num, kv_len, q_tile_num). + const int target_blocks = config_gridx / 3; // sm_count * 3 + // Search chunk_size from 512 with step 128: {512, 640, 768, ...} + + const int cur_chunk_size = + min(min_chunk_size + lane_id * chunk_step, max_chunk_size); + int num_block_no_chunk = 0; + int max_kv_len_no_chunk = 0; + int num_block_all = 0; + for (int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + const int token_num_cur_batch = seq_lens_this_time[bid]; + const int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + const int q_tile_num = + div_up(token_num_cur_batch * group_size, q_tile_size); + num_block_no_chunk += q_tile_num * kv_num_heads; + max_kv_len_no_chunk = max(max_kv_len_no_chunk, kv_len_cur_batch); + const int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + num_block_all_shared[lane_id] = num_block_all; + __syncthreads(); + + // Step 3: find best chunk_size, then decide Scheme E vs split-KV + if (tid == 0 && wid == 0) { + // Strategy: + // 1. Must fill target_blocks (2*sm_count) to maintain SM concurrency + // 2. Among valid choices, prefer minimum per-SM max KV traffic + // (= waves * chunk_size, since kernel time = slowest SM) + // 3. Within 5% of minimum KV traffic, prefer larger chunk_size + int chunk_size_best = min_chunk_size; + int num_block_all_best = num_block_all_shared[0]; + // Step 1: find minimum kv_traffic among chunk_sizes that fill SMs + int64_t kv_traffic_min = INT64_MAX; + for (int i = 0; i < static_cast(block_size); i++) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic < kv_traffic_min) { + kv_traffic_min = kv_traffic; + } + } + // Step 2: if no chunk_size fills SMs, fall back to smallest + if (kv_traffic_min == INT64_MAX) { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + } else { + // Step 3: scan from largest chunk_size downward; accept the first + // one that fills SMs AND has kv_traffic within 20% of minimum + for (int i = block_size - 1; i >= 0; i--) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic <= kv_traffic_min + kv_traffic_min / 4) { + chunk_size_best = cs; + num_block_all_best = nb; + break; + } + } + } + + // Decide Scheme E: prefer when blocks fill SMs AND estimated latency + // is no worse than split-KV. + // Scheme E: waves_E * max_kv_len (few heavy blocks) + // Split-KV: waves_split * chunk_size_best (many light blocks) + // When no splitting is needed (num_block_all_best == num_block_no_chunk), + // Scheme E is strictly better (saves merge overhead). + bool use_scheme_e = false; + if (num_block_no_chunk >= target_blocks) { + if (num_block_all_best == num_block_no_chunk) { + use_scheme_e = true; + } else { + // target_blocks = sm_count * 3 ≈ CTAs per wave (sm_count × occupancy). + // Using target_blocks as denominator correctly accounts for occupancy + // in wave count estimation. + const int waves_e = div_up(num_block_no_chunk, target_blocks); + const int waves_split = div_up(num_block_all_best, target_blocks); + use_scheme_e = (static_cast(waves_e) * max_kv_len_no_chunk <= + static_cast(waves_split) * chunk_size_best); + } + } + + if (use_scheme_e) { + num_blocks[0] = num_block_no_chunk; + chunk_size[0] = INT_MAX; + chunk_size_res[0] = INT_MAX; + } else { + num_blocks[0] = num_block_all_best; + chunk_size[0] = chunk_size_best; + chunk_size_res[0] = chunk_size_best; + } + } + + __syncthreads(); + if (wid == 0) { + const int chunk_size_final = chunk_size_res[0]; + + int prev_offset = 0; + for (int base = 0; base < bsz; base += warp_size) { + const int bid = base + tid; + int num_block_cur = 0; + int q_tile_num = 0; + int kv_chunk_num = 0; + + if (bid < bsz) { + int token_num_cur_batch = seq_lens_this_time[bid]; + if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { + token_num_cur_batch = 0; + } + q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + const int kv_len_cur_batch = + seq_lens_decoder[bid] + token_num_cur_batch; + kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_final); + num_block_cur = q_tile_num * kv_chunk_num * kv_num_heads; + } + + // inclusive prefix sum + int x = num_block_cur; + for (int offset = 1; offset < warp_size; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (tid >= offset) x += y; + } + int bid_offset = x - num_block_cur; + int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1); + + // Write block_indices using int4 vectorized stores. + // Each entry is exactly 4 ints (bid, kv_head_id, kv_chunk_id, q_tile_id), + // matching int4 layout. This reduces 4 scalar stores to 1 vector store. + if (bid < bsz && num_block_cur > 0) { + int4* write_ptr = block_indices + prev_offset + bid_offset; + int flat_idx = 0; + const int kv_chunk_num_x_q_tile_num = kv_chunk_num * q_tile_num; +#pragma unroll 2 + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + const int head_base = kv_head_id * kv_chunk_num_x_q_tile_num; +#pragma unroll 2 + for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; kv_chunk_id++) { + const int chunk_base = head_base + kv_chunk_id * q_tile_num; +#pragma unroll + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + write_ptr[flat_idx] = + make_int4(bid, kv_head_id, kv_chunk_id, q_tile_id); + flat_idx++; + } + } + } + } + prev_offset += tile_sum; + } + } +} + +void ConfigForAttention( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace, shape:[block_num,4], block's + // indices with 4 dimension[batch_idx, + // kv_head_idx, kv_chunk_idx, q_tile_idx] + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + auto stream = seq_lens_encoder.stream(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); + int max_just_dec_len_this_time = max_len_cpu_ptr[4]; + + const uint32_t block_indices_ele_num = block_indices.size(); + + // decoder + if (max_just_dec_len_this_time > 0) { + CUDA_CHECK(cudaMemsetAsync(block_indices.data(), + 0, + block_indices_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK( + cudaMemsetAsync(num_blocks.data(), 0, sizeof(int32_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(chunk_size.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + const int config_gridx = sm_cout * 6; + + const int q_tile_size = 16; + dim3 blocks(32, 4); + // Cast block_indices to int4* for vectorized stores. + // Each block_indices entry is 4 ints = 16 bytes = sizeof(int4), + // and block_num * 4 ints = block_num int4s, so the reinterpret is valid. + int4* block_indices_i4 = reinterpret_cast(block_indices.data()); + if (cache_quant_type == "cache_int4_zp") { + config_decode_attn<512, 256, 128, 32768> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } else { + config_decode_attn<512, 128, 128, 16384> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } + } +} + +std::vector> ConfigForAttentionInferShape( + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& max_len_tensor_cpu_shape, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +std::vector ConfigForAttentionInferDtype( + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& max_len_tensor_cpu_dtype, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +PD_BUILD_STATIC_OP(config_for_attention) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "block_indices", + "num_blocks", + "chunk_size", + "max_len_tensor_cpu", + }) + .Outputs({ + + }) + .Attrs({"cache_quant_type: std::string", + "group_size: int", + "kv_num_heads: int", + "max_tokens_per_batch: int"}) + .SetKernelFn(PD_KERNEL(ConfigForAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(ConfigForAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConfigForAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh new file mode 100644 index 00000000000..ff84e1cd3f6 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +template +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template +CUtensorMap makeTensorMapForKVCache(T const* addr, + uint32_t block_num, + uint32_t kv_num_head, + uint32_t second_size, + uint32_t last_size) { + CUtensorMap tensorMap{}; + + uint32_t elem_bytes = sizeof(T); + + uint32_t const last_size_bytes = elem_bytes * last_size; + // VLLM Layout + CUtensorMapDataType data_dtype = cu_tensor_map_type_traits::type; + constexpr uint32_t rank = 4; + uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num}; + uint64_t global_strides[] = {last_size_bytes, + second_size * last_size_bytes, + kv_num_head * second_size * last_size_bytes}; + + uint32_t box_dims[] = {last_size, second_size, 1, 1}; + uint32_t elem_strides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (last_size_bytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache last_size"); + } + }(); + CUresult res = cuTensorMapEncodeTiled( + &tensorMap, + data_dtype, + rank, + reinterpret_cast(const_cast(addr)), + global_dims, + global_strides, + box_dims, + elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + switch (res) { + case CUDA_SUCCESS: + printf("CUDA_SUCCESS!\n"); + break; + case CUDA_ERROR_INVALID_VALUE: + printf("CUDA_ERROR_INVALID_VALUE\n"); + break; + case CUDA_ERROR_OUT_OF_MEMORY: + printf("CUDA_ERROR_OUT_OF_MEMORY\n"); + break; + case CUDA_ERROR_NOT_INITIALIZED: + printf("CUDA_ERROR_NOT_INITIALIZED\n"); + break; + case CUDA_ERROR_DEINITIALIZED: + printf("CUDA_ERROR_DEINITIALIZED\n"); + break; + case CUDA_ERROR_PROFILER_DISABLED: + printf("CUDA_ERROR_PROFILER_DISABLED\n"); + break; + default: + throw std::runtime_error("unsupported res!"); + } + + return tensorMap; +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh new file mode 100644 index 00000000000..e30588a01ab --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh @@ -0,0 +1,492 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +#include "attention_func.cuh" + +template +__global__ void decode_unified_attention_c16_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x * 16 + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + T* o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + BLOCK_SIZE); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (BLOCK_SIZE); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx = chunk_start; + int block_table_idx = kv_idx / BLOCK_SIZE; + int block_id = __ldg(&block_table_now[block_table_idx]); + int block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T* cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T* cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + if (iter + 1 < num_iterations) { + block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + } + + wait_group<1>(); + __syncthreads(); + + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx += BLOCK_SIZE; + block_table_idx++; + + block_id = block_id_next; + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC16Attention( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size_0 = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(NV_TYPE); + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = + decode_unified_attention_c16_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = + reinterpret_cast(const_cast(cache_k.data())); + params.cache_v = + reinterpret_cast(const_cast(cache_v.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + 0.f, + 0.f, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh new file mode 100644 index 00000000000..00a20165555 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh @@ -0,0 +1,706 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +// #include "cu_tensor_map.cuh" +#include "attention_func.cuh" + +template +void print_params(AttentionParams const params) { + printf("max_model_len: %d\n", params.max_model_len); + printf("max_kv_len: %d\n", params.max_kv_len); + printf("max_blocks_per_seq: %d\n", params.max_blocks_per_seq); + printf("softmax_scale: %f\n", params.softmax_scale); + printf("quant_max_bound: %f\n", params.quant_max_bound); + printf("quant_min_bound: %f\n", params.quant_min_bound); + printf("max_tokens_per_batch: %d\n", params.max_tokens_per_batch); + printf("attn_mask_len: %d\n", params.attn_mask_len); + printf("sliding_window: %d\n", params.sliding_window); + printf("q_num_heads: %d\n", params.q_num_heads); + printf("kv_num_heads: %d\n", params.kv_num_heads); + printf("max_num_chunks: %d\n", params.max_num_chunks); + printf("max_tile_q: %d\n", params.max_tile_q); + printf("batch_size: %d\n", params.batch_size); +} + +template +__global__ void decode_unified_attention_c8_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto cache_k_scale = params.cache_k_scale; + const auto cache_v_scale = params.cache_v_scale; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + smem_t k_scale_smem; + smem_t v_scale_smem; + T* k_smem_scale_ptr = nullptr; + T* v_smem_scale_ptr = nullptr; + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + T cache_k_scale_reg[IsDynamicC8 + ? num_frags_z * 2 + : (is_scale_channel_wise ? num_frags_y * 4 : 1)]; + T cache_v_scale_reg[IsDynamicC8 + ? num_frags_z * 4 + : (is_scale_channel_wise ? num_frags_y * 2 : 1)]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T* cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T* cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + T* o_base_ptr_T = nullptr; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_idx * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg(k_smem_scale_ptr, + cache_k_scale_reg); + } + + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg(v_smem_scale_ptr, + cache_v_scale_reg); + } + + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC8Attention(const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::Tensor& cache_k_scale, + const paddle::Tensor& cache_v_scale, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size_0 = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = decode_unified_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = decode_unified_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = const_cast(cache_k.data()); + params.cache_v = const_cast(cache_v.data()); + params.cache_k_scale = + reinterpret_cast(const_cast(cache_k_scale.data())); + params.cache_v_scale = + reinterpret_cast(const_cast(cache_v_scale.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.quant_max_bound = quant_max_bound; + params.quant_min_bound = quant_min_bound; + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh new file mode 100644 index 00000000000..18788858923 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh @@ -0,0 +1,389 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +enum class SharedMemFillMode { kFillZero, kNoFill }; + +enum class PrefetchMode { kNoPrefetch, kPrefetch }; + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, + T* smem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +} + +__device__ __forceinline__ void commit_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +template +__device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } +#else + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(16)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), + (void*)gmem_ptr, + src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16); + } + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( + smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } else { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16), + "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(16)); + } + } +#endif +} + +template +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(8)); + } +#endif +} + +template +__device__ __forceinline__ void pred_load_32b(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy( + __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4); + } + } +#else + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + asm volatile( + "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4), + "r"(src_in_bytes)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(4)); + } +#endif +} + +template +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128, "num_bits must be 128"); + load_128b(smem_ptr, gmem_ptr); +} + +template +__device__ __forceinline__ void pred_load(T* smem_ptr, + const T* gmem_ptr, + bool predicate) { + static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32, + "num_bits must be 128, 64 or 32."); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 64) { + pred_load_64b(smem_ptr, gmem_ptr, predicate); + } else if constexpr (num_bits == 32) { + pred_load_32b(smem_ptr, gmem_ptr, predicate); + } +} + +using b32_t = uint32_t; +using b64_t = uint2; +using b128_t = uint4; + +template +constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() { + return sizeof(b128_t) / sizeof(T); +} + +struct smem_t { + // The base pointer. + b128_t* base; + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {} + + template + static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, + uint32_t j) { + if constexpr (inv_stride <= 1) { + return i * stride + (j ^ (i % 8)); + } else { + return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^ + ((i / inv_stride) % 8); + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + if constexpr (row_stride == 2) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } else if constexpr (row_stride == 4) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } + } + + template + static __device__ __forceinline__ uint32_t + advance_offset_by_row(uint32_t offset) { + if constexpr (row_stride == 2) { + static_assert(step_size == 16 || step_size % 32 == 0, + "Unsupported step size"); + if constexpr (step_size == 16) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 32 == 0 + return offset + step_size * row_stride; + } + } else if constexpr (row_stride == 4) { + static_assert(step_size == 8 || step_size % 16 == 0, + "Unsupported step size"); + if constexpr (step_size == 8) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 16 == 0 + return offset + step_size * row_stride; + } + } else { + static_assert(step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * row_stride; + } else { + // step_size % 8 == 0 + return offset + step_size * row_stride; + } + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_impl(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, + uint32_t* R) { + b128_t* smem_ptr = base + offset; + ldmatrix_m8n8x4_trans_impl(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr, + bool predicate) { + b128_t* smem_ptr = base + offset; + pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(uint32_t offset, + const T* gptr) { + b128_t* smem_ptr = base + offset; + load_128b(smem_ptr, + reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) { + *reinterpret_cast(gptr) = *(base + offset); + } +}; diff --git a/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh new file mode 100644 index 00000000000..8662ee298d2 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh @@ -0,0 +1,296 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32( + int* C, // 8 + uint32_t* A, // 4 + uint32_t* B) { // 4 + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(0), + "r"(0), + "r"(0), + "r"(0)); + } else { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "r"(C[0]), + "r"(C[1]), + "r"(C[2]), + "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "r"(C[4]), + "r"(C[5]), + "r"(C[6]), + "r"(C[7])); + } +} + +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + float* C, uint32_t* A, uint32_t* B) { + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same::value) { // fp16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } else { // bf16 + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(0.f), + "f"(0.f), + "f"(0.f), + "f"(0.f)); + } + } else { + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(A[2]), + "r"(A[3]), + "r"(B[2]), + "r"(B[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7])); + } + } +} + +template +__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { + static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); + uint32_t* s_u32 = (uint32_t*)(s); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1006648320), + "r"(1006648320), + "f"(d[0]), + "f"(d[1])); + } else { + asm volatile( + "{\n" + ".reg .f32 ph;\n" + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, ph, %1, ph}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), + "r"(s_u32[1]), + "r"(s_u32[2]), + "r"(s_u32[3]), + "r"(1065369472), + "r"(1065369472), + "f"(d[0]), + "f"(d[1])); + } +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/template_config.json b/custom_ops/gpu_ops/decode_unified_attention/template_config.json new file mode 100644 index 00000000000..d768c93a1ad --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/template_config.json @@ -0,0 +1,78 @@ +{ + "multiquery_attention_c8": { + "name": "decode_unified_attention_c8_kernel", + "function_name": "decode_unified_attention_c8_kernel", + "impl_file": "decode_unified_attention_c8_impl.cuh", + "template_params": [ + "T", + "CacheT", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_y", + "num_frags_z", + "is_scale_channel_wise", + "IsFP8", + "IsDynamicC8" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "CacheT": ["uint8_t"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_y": [8], + "num_frags_z": [1], + "is_scale_channel_wise": [0, 1], + "IsFP8": [0, 1], + "IsDynamicC8": [0, 1] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_unified_attention_c8", + "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + }, + "multiquery_attention_c16": { + "name": "decode_unified_attention_c16_kernel", + "function_name": "decode_unified_attention_c16_kernel", + "impl_file": "decode_unified_attention_c16_impl.cuh", + "template_params": [ + "T", + "GROUP_SIZE", + "CAUSAL", + "NUM_WARPS", + "NUM_WARP_Q", + "NUM_WARP_KV", + "HEAD_DIM", + "BLOCK_SIZE", + "num_frags_x", + "num_frags_z", + "num_frags_y" + ], + "dispatch_params": { + "T": ["half", "__nv_bfloat16"], + "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16], + "CAUSAL": [0, 1], + "NUM_WARPS": [4], + "NUM_WARP_Q": [1], + "NUM_WARP_KV": [4], + "HEAD_DIM": [128], + "BLOCK_SIZE": [64], + "num_frags_x": [1, 2], + "num_frags_z": [1], + "num_frags_y": [8] + }, + "max_instances_per_file": 80, + "file_prefix": "decode_unified_attention_c16", + "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n" + } +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/utils.cuh b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh new file mode 100644 index 00000000000..7111ad23fb7 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh @@ -0,0 +1,689 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include +#include +#include +#include "helper.h" +#include "mem_util.cuh" + +#define NUM_WARPS_PER_BLOCK 4 +#define NUM_THREADS_PER_BLOCK 128 +#define kWarpSize 32 + +#define HOSTDEVICE __host__ __device__ + +/*-------------------------------------traits-----------------------------------------*/ +template +struct type_traits { + using paddle_type = T; + using phi_type = T; + using nv_type = T; + using nv2_type = T; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::float16; +// using nv_type = half; +// using nv2_type = half2; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::float16; + using nv_type = half; + using nv2_type = half2; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT16; +// using phi_type = phi::dtype::bfloat16; +// using nv_type = __nv_bfloat16; +// using nv2_type = __nv_bfloat162; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat16> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +template <> +struct type_traits<__nv_bfloat162> { + // using paddle_type = paddle::DataType::FLOAT16; + using phi_type = phi::dtype::bfloat16; + using nv_type = __nv_bfloat16; + using nv2_type = __nv_bfloat162; +}; + +// template <> +// struct type_traits { +// using paddle_type = paddle::DataType::FLOAT8_E4M3FN; +// using phi_type = phi::dtype::float8_e4m3fn; +// using nv_type = __nv_fp8_e4m3; +// using nv2_type = __nv_fp8x2_e4m3; +// }; + +template <> +struct type_traits { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; + +template <> +struct type_traits<__nv_fp8x2_e4m3> { + // using paddle_type = paddle::DataType::FLOAT8_E4M3FN; + using phi_type = phi::dtype::float8_e4m3fn; + using nv_type = __nv_fp8_e4m3; + using nv2_type = __nv_fp8x2_e4m3; +}; +/*---------------------------------1. type + * traits--------------------------------------*/ + +/*---------------------------------2. fast + * convert--------------------------------------*/ +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { + printf("Do not support fp8 to half although it's very easy.\n"); +} + +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); + + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif +} + +inline __device__ static void convert_int8( + half* result, const uint32_t& source) { // 4 int8 each time + uint32_t* fp16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +} + +inline __device__ static void convert_int8( + __nv_bfloat16* result, const uint32_t& source) { // 4 int8 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; // (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +} +/*---------------------------------2. fast + * convert--------------------------------------*/ + +/*---------------------------------3. vector + * cast--------------------------------------*/ +template +__forceinline__ HOSTDEVICE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(float* dst, + const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(half* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast( + float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} +/*---------------------------------3. vector + * cast--------------------------------------*/ + +/*-------------------------------------4. + * func-----------------------------------------*/ +__forceinline__ HOSTDEVICE int div_up(int a, int b) { + return a / b + (a % b != 0); +} + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +HOSTDEVICE __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { + uint8_t eight_bits; + float quant_value; + if constexpr (is_need_kv_quant) { + quant_value = static_cast(scale * value); + } else { + quant_value = static_cast(value); + } + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + + if constexpr (IsFP8) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + quant_value = quant_value > 448.0f ? 448.0f : quant_value; + quant_value = quant_value < -448.0f ? -448.0f : quant_value; + auto tmp = static_cast<__nv_fp8_e4m3>(quant_value); + eight_bits = *(reinterpret_cast(&tmp)); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif + } else { + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + eight_bits = static_cast(quant_value + 128.0f); + } + return eight_bits; +} + +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { + if constexpr (IsFP8) { + convert_fp8(result, source); + } else { + convert_int8(result, source); + } +} + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +#define CHECK_CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +__device__ __forceinline__ float2 fast_float2_mul(const float2& a, + const float2& b) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, 0.0;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, 0.0;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + return res; +} + +__device__ __forceinline__ float2 fast_float2_fma(float2& a, + const float2& b, + const float2& c) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, %6;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, %7;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), + "f"(a.y), + "f"(b.x), + "f"(b.y), + "f"(c.x), + "f"(c.y) // 输入操作数 + ); + return res; +} + +// __device__ __forceinline__ float2 fast_bfloat162_fma(__nv_bfloat162& a_bf162, +// const __nv_bfloat162& b_bf162, const __nv_bfloat162& c_bf162) { +// // 使用向量化PTX指令同时处理x/y分量 +// asm volatile ( +// "{\n" +// " fma.rn.b16 %0, %2, %4, %0;\n" // res.x = a.x * b.x +// " fma.rn.b16 %1, %3, %5, %1;\n" // res.y = a.y * b.y +// "}" +// : "=r"(a_bf162.x), "=r"(a_bf162.y) // 输出操作数 +// : "r"(b_bf162.x), "r"(b_bf162.y), +// "r"(c_bf162.x), "r"(c_bf162.y) // 输入操作数 +// ); +// float2 res = __bfloat1622float2_rn(a_bf162); +// return res; +// } + +__device__ __forceinline__ float2 fast_float2_sub_expf(const float2& a, + const float2& b) { + float2 res; + // 使用向量化减法指令(PTX sub.rn.f32) + asm volatile( + "{\n" + " sub.f32 %0, %2, %4;\n" // res.x = a.x - b.x + " sub.f32 %1, %3, %5;\n" // res.y = a.y - b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + res.x = expf(res.x); + res.y = expf(res.y); + return res; +} + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = static_cast(ori_out_vec[i]); + printf("Fatal! Unimplemented StoreFunc for cascade append attention\n"); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + 127.0f * + static_cast((ori_out_vec[i] + shift_bias_vec[i]) * + smooth_weight_vec[i]) * + in_scale; + quant_value = rintf(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; +/*-------------------------------------4. + * func-----------------------------------------*/ + +/*-----------------------------------5. + * dispatch---------------------------------------*/ +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_Q_TILE_SIZE( \ + group_size, max_tokens_per_batch, Q_TILE_SIZE, ...) \ + { \ + constexpr size_t Q_TILE_SIZE = 16; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \ + block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_IS_FP8(is_fp8, IS_FP8, ...) \ + if (is_fp8) { \ + constexpr bool IS_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IS_FP8 = false; \ + __VA_ARGS__ \ + } + +struct AppendAttnMetaData { + int batch_size; + int block_size; + int q_num_heads; + int kv_num_heads; + int token_num; + int head_dims; + int head_dims_v; + int max_blocks_per_seq; + const int* mask_offset = nullptr; +}; + +template +struct AttentionParams { + T* __restrict__ qkv; + CacheT* __restrict__ cache_k; + CacheT* __restrict__ cache_v; + T* __restrict__ cache_k_scale; + T* __restrict__ cache_v_scale; + int* __restrict__ seq_lens_q; + int* __restrict__ seq_lens_kv; + int* __restrict__ block_indices; + int* __restrict__ num_blocks_ptr; + int* __restrict__ chunk_size_ptr; + int* __restrict__ cu_seqlens_q; + int* __restrict__ block_table; + int* __restrict__ mask_offset; + bool* __restrict__ attn_mask; + T* __restrict__ tmp_o; + float* __restrict__ tmp_m; + float* __restrict__ tmp_d; + int max_model_len; + int max_kv_len; + int max_blocks_per_seq; + float softmax_scale; + float quant_max_bound; + float quant_min_bound; + int num_blocks_x; + int attn_mask_len; + bool sliding_window; + int q_num_heads; + int kv_num_heads; + int max_num_chunks; + int max_tile_q; + int batch_size; + int token_num; + int head_dims; + int max_tokens_per_batch; +}; diff --git a/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu new file mode 100644 index 00000000000..7878e9926c5 --- /dev/null +++ b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu @@ -0,0 +1,326 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + auto stream = qkv.stream(); + + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + + if (max_just_dec_len_this_time > 0) { + if (speculate_decoder) { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } else { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } + } + return {qkv}; +} + +std::vector> DecoderWriteCacheWithRoPEInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_shape}; +} + +std::vector DecoderWriteCacheWithRoPEInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "set_max_lengths", + paddle::Optional("rotary_embs"), + paddle::Optional("qkv_bias"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .Attrs({ + "rms_norm_eps: float", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "speculate_decoder: bool", + }) + .SetKernelFn(PD_KERNEL(DecoderWriteCacheWithRoPE)) + .SetInferShapeFn(PD_INFER_SHAPE(DecoderWriteCacheWithRoPEInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecoderWriteCacheWithRoPEInferDtype)); diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index b0ca5e2c0ce..dce65b97274 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -54,7 +54,6 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, PADDLE_ENFORCE(k_token_num == v_input.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(head_dim == 128, "Unmatched shape"); PADDLE_ENFORCE(batch_size > 0, "Unmatched shape"); - PADDLE_ENFORCE(batch_size == seq_len_encoder.dims()[0], "Unmatched shape"); PADDLE_ENFORCE(batch_size == cu_seq_k.dims()[0] - 1, "Unmatched shape"); constexpr int kBlockM = 128; diff --git a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp index cb76da20d6a..277ed46f851 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp @@ -490,6 +490,23 @@ struct CollectiveMainloopAttn { softmax.rescale_o(tOrO, scores_scale); consumer_wait(pipeline_v, smem_pipe_read_v); + if (seq_len_k - n_block * kBlockN < kBlockN) { + int valid_k = seq_len_k - n_block * kBlockN; + auto sVt_this = sVt(_, _, smem_pipe_read_v.index()); + constexpr int kHdLo = decltype(get<0, 0>(shape(sVt_this)))::value; + constexpr int kHdHi = decltype(get<0, 1>(shape(sVt_this)))::value; + if (thread_idx >= valid_k && thread_idx < kBlockN) { +#pragma unroll + for (int hd_hi = 0; hd_hi < kHdHi; ++hd_hi) { +#pragma unroll + for (int hd_lo = 0; hd_lo < kHdLo; ++hd_lo) { + sVt_this(make_coord(make_coord(hd_lo, hd_hi), thread_idx)) = + Element(0); + } + } + } + cutlass::arch::fence_view_async_shared(); + } gemm( tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); diff --git a/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu new file mode 100644 index 00000000000..f25084076c4 --- /dev/null +++ b/custom_ops/gpu_ops/fused_cast_sigmoid_bias.cu @@ -0,0 +1,206 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +// Fused kernel: cast(input, cast_type) -> sigmoid -> scores, scores + bias -> +// scores_with_bias +// +// For each element (token i, expert j): +// scores[i][j] = OutT(sigmoid(float(input[i][j]))) +// scores_with_bias[i][j] = OutT(sigmoid(float(input[i][j])) + bias[j]) +// +// Input: input [num_tokens, num_experts] bf16/fp16/fp32 +// bias [num_experts] or [1, num_experts] fp32 +// Output: scores [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// scores_with_bias [num_tokens, num_experts] cast_type (fp32/fp16/bf16) +// +// Precision guarantee: +// All intermediate computations (cast, sigmoid, bias addition) are performed +// in float32, regardless of input/output types. The cast to OutT only happens +// at the final store. This matches the reference implementation: +// gate_fp32 = gate_out.cast("float32") +// scores_fp32 = sigmoid(gate_fp32) +// scores_with_bias_fp32 = scores_fp32 + bias // bias is always float32 +// scores = scores_fp32.cast(cast_type) +// scores_with_bias = scores_with_bias_fp32.cast(cast_type) +// +// When cast_type is "float32", the fused kernel is numerically identical to +// the reference. For fp16/bf16 output, the only precision loss comes from +// the final static_cast, equivalent to .cast() in the reference path. +// +// Note: bias is intentionally kept as float32 (not converted to OutT) to +// ensure the addition s + bias[j] is always computed in full float32 +// precision before the final downcast. + +template +__global__ void fused_cast_sigmoid_bias_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + for (int j = threadIdx.x; j < num_experts; j += blockDim.x) { + // All intermediate computation in float32 for precision + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias[j] (float32) -> float32 addition, then downcast + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +// Vectorized version for better memory throughput +template +__global__ void fused_cast_sigmoid_bias_vec_kernel( + const InT* __restrict__ input, + const float* __restrict__ bias, // kept as float32 for full-precision add + OutT* __restrict__ scores, + OutT* __restrict__ scores_with_bias, + const int num_experts) { + const int64_t token_idx = blockIdx.x; + const int64_t offset = token_idx * num_experts; + + using in_vec_t = AlignedVector; + using out_vec_t = AlignedVector; + using bias_vec_t = AlignedVector; // float32 bias vectors + + const int vec_count = num_experts / kVecSize; + for (int idx = threadIdx.x; idx < vec_count; idx += blockDim.x) { + const int base = idx * kVecSize; + in_vec_t in_vec; + bias_vec_t bias_vec; + Load(input + offset + base, &in_vec); + Load(bias + base, &bias_vec); + + out_vec_t s_vec, sb_vec; +#pragma unroll + for (int i = 0; i < kVecSize; ++i) { + // All intermediate computation in float32 for precision + float val = static_cast(in_vec[i]); + float s = 1.0f / (1.0f + expf(-val)); + // s (float32) + bias_vec[i] (float32) -> float32 addition, then downcast + s_vec[i] = static_cast(s); + sb_vec[i] = static_cast(s + bias_vec[i]); + } + + Store(s_vec, scores + offset + base); + Store(sb_vec, scores_with_bias + offset + base); + } + + // Handle remaining elements (same float32 precision guarantee) + const int remaining_start = vec_count * kVecSize; + for (int j = remaining_start + threadIdx.x; j < num_experts; + j += blockDim.x) { + float val = static_cast(input[offset + j]); + float s = 1.0f / (1.0f + expf(-val)); + scores[offset + j] = static_cast(s); + scores_with_bias[offset + j] = static_cast(s + bias[j]); + } +} + +static paddle::DataType ParseCastType(const std::string& cast_type) { + if (cast_type == "float32") return paddle::DataType::FLOAT32; + if (cast_type == "float16") return paddle::DataType::FLOAT16; + if (cast_type == "bfloat16") return paddle::DataType::BFLOAT16; + PD_THROW("Unsupported cast_type: " + cast_type + + ". Only float32, float16, bfloat16 are supported."); +} + +std::vector FusedCastSigmoidBias(const paddle::Tensor& input, + const paddle::Tensor& bias, + std::string cast_type) { + auto input_shape = input.shape(); + PD_CHECK(input_shape.size() == 2, + "input must be 2D [num_tokens, num_experts]"); + auto bias_shape = bias.shape(); + // Support both [num_experts] and [1, num_experts] bias shapes + PD_CHECK( + bias_shape.size() == 1 || (bias_shape.size() == 2 && bias_shape[0] == 1), + "bias must be 1D [num_experts] or 2D [1, num_experts]"); + + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + int64_t bias_numel = (bias_shape.size() == 1) ? bias_shape[0] : bias_shape[1]; + PD_CHECK(bias_numel == num_experts, "bias size must match num_experts"); + PD_CHECK(bias.dtype() == paddle::DataType::FLOAT32, + "bias must be float32, got ", + bias.dtype()); + + auto place = input.place(); + auto stream = input.stream(); + auto out_dtype = ParseCastType(cast_type); + + auto scores = paddle::empty({num_tokens, num_experts}, out_dtype, place); + auto scores_with_bias = + paddle::empty({num_tokens, num_experts}, out_dtype, place); + + if (num_tokens == 0) { + return {scores, scores_with_bias}; + } + + dim3 grid(num_tokens); + int block_size = std::min(static_cast(1024), num_experts); + // Round up to warp size + block_size = ((block_size + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + dim3 block(block_size); + + DISPATCH_FLOAT_FP6_DTYPE(input.dtype(), in_scalar_t, { + DISPATCH_FLOAT_FP6_DTYPE(out_dtype, out_scalar_t, { + constexpr int kVecSize = 16 / sizeof(in_scalar_t); + if (num_experts % kVecSize == 0 && num_experts >= kVecSize) { + fused_cast_sigmoid_bias_vec_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } else { + fused_cast_sigmoid_bias_kernel + <<>>(input.data(), + bias.data(), + scores.data(), + scores_with_bias.data(), + num_experts); + } + }); + }); + + return {scores, scores_with_bias}; +} + +std::vector FusedCastSigmoidBiasInferDtype( + const paddle::DataType& input_dtype, + const paddle::DataType& bias_dtype, + std::string cast_type) { + auto out_dtype = ParseCastType(cast_type); + return {out_dtype, out_dtype}; +} + +std::vector> FusedCastSigmoidBiasInferShape( + const std::vector& input_shape, + const std::vector& bias_shape) { + return {input_shape, input_shape}; +} + +PD_BUILD_STATIC_OP(fused_cast_sigmoid_bias) + .Inputs({"input", "bias"}) + .Outputs({"scores", "scores_with_bias"}) + .Attrs({"cast_type: std::string"}) + .SetKernelFn(PD_KERNEL(FusedCastSigmoidBias)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedCastSigmoidBiasInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(FusedCastSigmoidBiasInferDtype)); diff --git a/custom_ops/gpu_ops/fused_rotary_position_encoding.cu b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu index 05c94b60b74..8ec225d5c29 100644 --- a/custom_ops/gpu_ops/fused_rotary_position_encoding.cu +++ b/custom_ops/gpu_ops/fused_rotary_position_encoding.cu @@ -53,9 +53,13 @@ __global__ void apply_rotary_embedding_kernel( const int64_t key_stride, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; + const int head_size, + const int num_tokens) { // 新增 num_tokens 参数用于边界检查 + + // 用2D grid表示token_idx,突破65535限制 + const int token_idx = blockIdx.x + blockIdx.y * gridDim.x; + if (token_idx >= num_tokens) return; // 边界保护 + int pos = position_ids[token_idx]; const T* cache_ptr = cos_sin_cache + pos * rot_dim; @@ -99,13 +103,13 @@ void FusedRotaryPositionEncoding( int64_t query_stride = num_heads * head_size; int64_t key_stride = num_kv_heads * head_size; - if (num_tokens > 65535) { - PD_THROW( - "apply_rotary_embedding_kernel launch failed when num_tokens > 65535."); - } - - dim3 grid(num_tokens); + // 拆成2D grid:每维最大65535,总计支持 65535*65535 >> 1024*1024 + constexpr int MAX_GRID_X = 65535; + int grid_x = std::min(num_tokens, MAX_GRID_X); + int grid_y = (num_tokens + MAX_GRID_X - 1) / MAX_GRID_X; + dim3 grid(grid_x, grid_y); dim3 block(std::min(num_heads * rot_dim / 2, 512)); + PD_DISPATCH_FLOATING_AND_HALF_TYPES( query.dtype(), "apply_rotary_embedding_kernel", [&] { if (is_neox) { @@ -119,7 +123,8 @@ void FusedRotaryPositionEncoding( key_stride, num_heads, num_kv_heads, - head_size); + head_size, + num_tokens); } else { apply_rotary_embedding_kernel <<>>(query.data(), @@ -131,7 +136,8 @@ void FusedRotaryPositionEncoding( key_stride, num_heads, num_kv_heads, - head_size); + head_size, + num_tokens); } }); } diff --git a/custom_ops/gpu_ops/get_attn_mask_q.cu b/custom_ops/gpu_ops/get_attn_mask_q.cu index 4ee814178bc..a485d04f6bc 100644 --- a/custom_ops/gpu_ops/get_attn_mask_q.cu +++ b/custom_ops/gpu_ops/get_attn_mask_q.cu @@ -24,7 +24,7 @@ __global__ void get_attn_mask_q_kernel( const int max_batch_size) { constexpr int VecSize = 4; const uint32_t tid = threadIdx.x, bid = blockIdx.x; - int startend_row_vec[4]; + int startend_row_vec[2]; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif @@ -49,9 +49,9 @@ __global__ void get_attn_mask_q_kernel( const uint32_t cache_k_idx = cu_seqlens_k_idx - kv_start; startend_row_vec[0] = this_batch_q_end; - startend_row_vec[1] = cu_seqlens_q[max_batch_size]; - startend_row_vec[2] = 0; - startend_row_vec[3] = this_batch_q_end; + // startend_row_vec[1] = cu_seqlens_q[max_batch_size]; + // startend_row_vec[2] = 0; + startend_row_vec[1] = this_batch_q_end; for (int this_batch_q_idx = this_batch_q_start; this_batch_q_idx < this_batch_q_end; ++this_batch_q_idx) { @@ -62,14 +62,14 @@ __global__ void get_attn_mask_q_kernel( : this_batch_q_idx - this_batch_q_start + kv_len - (this_batch_q_len); if (cache_k_idx <= append_mask_k_end) { - startend_row_vec[3] = min(startend_row_vec[3], this_batch_q_idx); + startend_row_vec[1] = min(startend_row_vec[1], this_batch_q_idx); // 可提前跳出循环 break; } } - reinterpret_cast(startend_row_indices_ptr + - cu_seqlens_k_idx * 4)[0] = - reinterpret_cast(startend_row_vec)[0]; + reinterpret_cast(startend_row_indices_ptr + + cu_seqlens_k_idx * 2)[0] = + reinterpret_cast(startend_row_vec)[0]; } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); @@ -82,7 +82,7 @@ std::vector get_attn_mask_q( const paddle::optional& attn_mask_kv, const int kv_token_num) { paddle::Tensor attn_mask_startend_row_indices = GetEmptyTensor( - {1, 1, kv_token_num, 4}, paddle::DataType::INT32, cu_seqlens_k.place()); + {1, 1, kv_token_num, 2}, paddle::DataType::INT32, cu_seqlens_k.place()); const int max_batch_size = cu_seqlens_k.dims()[0] - 1; constexpr int block_size = 512; int grid_size = div_up(kv_token_num, block_size); @@ -123,7 +123,7 @@ std::vector> GetAttnMaskQInferShape( const std::vector& cu_seqlens_k_shape, const paddle::optional>& attn_mask_kv_shape, const int kv_token_num) { - return {{1, 1, kv_token_num, 4}}; + return {{1, 1, kv_token_num, 2}}; } PD_BUILD_STATIC_OP(get_attn_mask_q) diff --git a/custom_ops/gpu_ops/get_output_msg_with_topk.cc b/custom_ops/gpu_ops/get_output_msg_with_topk.cc index e70f7c2c24d..363274d6aac 100644 --- a/custom_ops/gpu_ops/get_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/get_output_msg_with_topk.cc @@ -88,13 +88,17 @@ void GetOutputTopK(const paddle::Tensor& x, return; } - int bsz = msg_rcv.mtext[1]; + // Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1]. + // This matches the packing in save_output_msg_with_topk.cc: + // mtext[1] = bsz | (max_num_logprobs << 16) + int bsz = msg_rcv.mtext[1] & 0xFFFF; + int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF; out_data[0] = (int64_t)msg_rcv.mtext[0]; - out_data[1] = (int64_t)msg_rcv.mtext[1]; + out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks for (int i = 0; i < bsz; i++) { - for (int j = 0; j < k + 1; j++) { - const int64_t offset = i * (K + 1) + j; + for (int j = 0; j < actual_topk; j++) { + const int64_t offset = i * actual_topk + j; out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2]; scores_data[offset] = msg_rcv.mtext_f[offset]; } diff --git a/custom_ops/gpu_ops/get_position_ids.cu b/custom_ops/gpu_ops/get_position_ids.cu new file mode 100644 index 00000000000..3c04332934f --- /dev/null +++ b/custom_ops/gpu_ops/get_position_ids.cu @@ -0,0 +1,67 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void GetPositionIdsKernel(const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_this_time, + int* __restrict__ position_ids, + const int bsz) { + int current_bid = threadIdx.x; + if (current_bid >= bsz) return; + + // Caculate the offset of current batch in the position_ids buffer + int buffer_offset = 0; + for (int i = 0; i < current_bid; i++) { + buffer_offset += seq_lens_this_time[i]; + } + + // Caculate the token offset in the current batch + int token_offset = seq_lens_decoder[current_bid]; + int token_num_this_batch = seq_lens_this_time[current_bid]; + if (token_num_this_batch == 0) return; + +// Write position ids for current batch +#pragma unroll + for (int i = 0; i < token_num_this_batch; i++) { + position_ids[buffer_offset + i] = token_offset + i; + } +} + +void GetPositionIds(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& position_ids) { + const int bsz = seq_lens_this_time.shape()[0]; + + GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>( + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + const_cast(position_ids.data()), + bsz); +} + +PD_BUILD_STATIC_OP(get_position_ids) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "position_ids", + }) + .Outputs({"position_ids_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}}) + .SetKernelFn(PD_KERNEL(GetPositionIds)); diff --git a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu b/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu deleted file mode 100644 index 946c9754072..00000000000 --- a/custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "helper.h" -#include "paddle/extension.h" - -__global__ void GetPositionIdsAndMaskEncoderBatchKernel( - const int* seq_lens_encoder, // [bsz] 每个批次的 encoder 长度 - const int* seq_lens_decoder, // [bsz] 每个批次的 decoder 长度 - const int* seq_lens_this_time, - int* position_ids, // 输出的一维 position_ids - int* mask_encoder_batch, - const int bsz) { // 批次大小 - // 当前线程索引(每个线程对应一个批次) - int tid = threadIdx.x; - if (tid >= bsz) return; - - // 动态计算当前批次的偏移量 - int offset = 0; - for (int i = 0; i < tid; i++) { - offset += seq_lens_encoder[i]; - if (seq_lens_decoder[i] > 0) { - offset += seq_lens_this_time[i]; - } - } - - // 当前批次的 encoder 和 decoder 长度 - int encoder_len = seq_lens_encoder[tid]; - int decoder_len = seq_lens_decoder[tid]; - int seq_len_this_time = seq_lens_this_time[tid]; - - // 写入 encoder 的 position_ids - for (int i = 0; i < encoder_len; i++) { - position_ids[offset + i] = i; - mask_encoder_batch[offset + i] = 1; - } - offset += encoder_len; - - // 写入 decoder 的 position_ids - if (decoder_len > 0) { - for (int i = 0; i < seq_len_this_time; i++) { - position_ids[offset + i] = decoder_len + i; // 使用 decoder 长度本身 - mask_encoder_batch[offset + i] = 0; - } - } -} - -void GetPositionIdsAndMaskEncoderBatch( - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& seq_lens_decoder, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& position_ids, - const paddle::Tensor& mask_encoder_batch) { - const int bsz = seq_lens_this_time.shape()[0]; - - GetPositionIdsAndMaskEncoderBatchKernel<<<1, bsz, 0, position_ids.stream()>>>( - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - const_cast(position_ids.data()), - const_cast(mask_encoder_batch.data()), - bsz); -} - -PD_BUILD_STATIC_OP(get_position_ids_and_mask_encoder_batch) - .Inputs({"seq_lens_encoder", - "seq_lens_decoder", - "seq_lens_this_time", - "position_ids", - "mask_encoder_batch"}) - .Outputs({"position_ids_out", "mask_encoder_batch_out"}) - .SetInplaceMap({{"position_ids", "position_ids_out"}, - {"mask_encoder_batch", "mask_encoder_batch_out"}}) - .SetKernelFn(PD_KERNEL(GetPositionIdsAndMaskEncoderBatch)); diff --git a/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu b/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu new file mode 100644 index 00000000000..5c57a071461 --- /dev/null +++ b/custom_ops/gpu_ops/get_position_ids_and_slot_mapping.cu @@ -0,0 +1,108 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +__global__ void GetPositionIdsAndSlotMappingKernel( + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + const int* __restrict__ seq_lens_this_time, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ block_tables, + const int bsz, + const int max_num_blocks, + const int block_size, + int64_t* __restrict__ position_ids, + int64_t* __restrict__ slot_mapping) { + int current_bid = threadIdx.x; + if (current_bid >= bsz) return; + + // Calculate the offset of current batch in the position_ids buffer + int buffer_offset = 0; + for (int i = 0; i < current_bid; i++) { + buffer_offset += seq_lens_this_time[i]; + } + + // Calculate the token offset in the current batch + int token_offset = seq_lens_decoder[current_bid]; + int token_num_this_batch = seq_lens_this_time[current_bid]; + if (token_num_this_batch == 0) return; + + // Write position ids and slot mapping for current batch +#pragma unroll + for (int i = 0; i < token_num_this_batch; i++) { + int pos_id = token_offset + i; + int idx = buffer_offset + i; + + // Write position_id + position_ids[idx] = pos_id; + + // Calculate slot mapping directly + int block_idx = pos_id / block_size; + int block_offset = pos_id % block_size; + int batch_id = batch_id_per_token[idx]; + + // Get block_id from block_tables + int block_id = block_tables[batch_id * max_num_blocks + block_idx]; + + // Calculate slot mapping + slot_mapping[idx] = static_cast( + static_cast(block_id) * block_size + block_offset); + } +} + +void GetPositionIdsAndSlotMapping(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& block_tables, + const paddle::Tensor& position_ids, + const paddle::Tensor& slot_mapping, + const int block_size) { + const int bsz = seq_lens_this_time.shape()[0]; + const int total_token_num = position_ids.shape()[0]; + const int max_num_blocks = block_tables.shape()[1]; + + GetPositionIdsAndSlotMappingKernel<<<1, + bsz, + 0, + seq_lens_this_time.stream()>>>( + seq_lens_encoder.data(), + seq_lens_decoder.data(), + seq_lens_this_time.data(), + batch_id_per_token.data(), + block_tables.data(), + bsz, + max_num_blocks, + block_size, + const_cast(position_ids.data()), + const_cast(slot_mapping.data())); +} + +PD_BUILD_STATIC_OP(get_position_ids_and_slot_mapping) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "block_tables", + "position_ids", + "slot_mapping", + }) + .Attrs({"block_size: int"}) + .Outputs({"position_ids_out", "slot_mapping_out"}) + .SetInplaceMap({{"position_ids", "position_ids_out"}, + {"slot_mapping", "slot_mapping_out"}}) + .SetKernelFn(PD_KERNEL(GetPositionIdsAndSlotMapping)); diff --git a/custom_ops/gpu_ops/grouped_topk_kernels.cu b/custom_ops/gpu_ops/grouped_topk_kernels.cu new file mode 100644 index 00000000000..ef5ed8533f0 --- /dev/null +++ b/custom_ops/gpu_ops/grouped_topk_kernels.cu @@ -0,0 +1,786 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "helper.h" + +namespace cg = cooperative_groups; + +constexpr unsigned FUSED_FULL_WARP_MASK = 0xffffffff; + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ inline float cuda_cast(__half val) { + return __half2float(val); +} + +template <> +__device__ inline __half cuda_cast<__half, float>(float val) { + return __float2half(val); +} + +// Numerically stable sigmoid via tanh: σ(x) = 0.5 * tanh(0.5*x) + 0.5 +template +__device__ __forceinline__ T sigmoid_device(T x) { + float xf = cuda_cast(x); + return cuda_cast(0.5f * tanhf(0.5f * xf) + 0.5f); +} + +// Sigmoid matching fused_cast_sigmoid_bias: 1 / (1 + exp(-x)). +// Must use the same formula to get bit-identical results when comparing +// against the fused_cast_sigmoid_bias + noaux_tc path. +template +__device__ __forceinline__ float sigmoid_to_float(InT x) { + float xf = cuda_cast(x); + return 1.0f / (1.0f + expf(-xf)); +} + +template +__device__ inline T neg_inf() { + return cuda_cast(-cuda::std::numeric_limits::infinity()); +} + +template +__device__ inline bool is_finite_val(T val) { +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) + return cuda::std::isfinite(val); +#else + return isfinite(cuda_cast(val)); +#endif +} + +namespace warp_topk_fused { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) return 0; + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, + T baseline, + idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + return res; +} + +template +struct BitonicMerge { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) + is_better = is_better_than( + val, other_val, idx_arr[i], idx_arr[other_i]); + else + is_better = is_better_than(val, other_val); + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = + __shfl_xor_sync(FUSED_FULL_WARP_MASK, *idx_arr, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FUSED_FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FUSED_FULL_WARP_MASK, idx, stride); + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); + else + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + __device__ __forceinline__ idxT get_idx(int i = 0) const { + return idx_arr_[i]; + } + __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + int const lane_; + idxT const k_; + T const dummy_; +}; + +// WarpSelect WITHOUT __syncthreads() in done() — safe when only one warp is +// active. +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_idx_(0), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + else + do_add = is_better_than(val, k_th_); + + uint32_t mask = __ballot_sync(FUSED_FULL_WARP_MASK, do_add); + if (mask == 0) return; + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + // NOTE: no __syncthreads() here — callers must sync externally if needed. + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync( + FUSED_FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) + k_th_idx_ = __shfl_sync( + FUSED_FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + T& old = val_arr_[max_arr_len_ - 1]; + bool is_better; + if constexpr (is_stable) + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + else + is_better = is_better_than(val, old); + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + BitonicMerge::merge( + val_arr_, idx_arr_); + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; + +} // namespace warp_topk_fused + +// --------------------------------------------------------------------------- +// Fused kernel: group-score computation + group selection + expert topk +// + sparse scores write-back, in one kernel launch. +// +// gridDim = num_tokens (one block per token) +// blockDim = n_group * WARP_SIZE (one warp per group) +// --------------------------------------------------------------------------- +template +__global__ void grouped_topk_fused_kernel( + float* scores, // output: sparse routing weights [num_tokens, num_experts] + float* topk_values, // output: topk routing weights [num_tokens, topk] + IdxT* topk_indices, // output: topk expert indices [num_tokens, topk] + InT const* gating_output, // input: raw logits (float or bf16) + // [num_tokens, num_experts] + float const* e_score_correction_bias, // input: bias [num_experts] + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double routed_scaling_factor) { + int32_t const token_id = static_cast(blockIdx.x); + if (token_id >= static_cast(num_tokens)) return; + + int32_t const warp_id = threadIdx.x / WARP_SIZE; + int32_t const lane_id = threadIdx.x % WARP_SIZE; + int32_t const n_group_i32 = static_cast(n_group); + int32_t const topk_group_i32 = static_cast(topk_group); + int32_t const topk_i32 = static_cast(topk); + int32_t const num_warps = blockDim.x / WARP_SIZE; + + if (warp_id >= n_group_i32 || num_warps < n_group_i32) return; + + int32_t const num_experts_per_group = + static_cast(num_experts) / n_group_i32; + int32_t const align_epg = warp_topk_fused::round_up_to_multiple_of( + num_experts_per_group); + + InT const* gate_token = gating_output + (int64_t)token_id * num_experts; + float* scores_token = scores + (int64_t)token_id * num_experts; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + // smem layout: [val_staging (256B-aligned) | idx_staging | (16B pad) | + // s_group_scores] + extern __shared__ char smem_buf[]; + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + uintptr_t ptr = + (reinterpret_cast(smem_buf + val_aligned + idx_bytes) + 15) & + ~static_cast(15); + float* s_group_scores = reinterpret_cast(ptr); + float* s_topk_value = + reinterpret_cast(smem_buf); // val_staging (256B-aligned) + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + + // ------------------------------------------------------------------ + // Phase 1 (all warps): compute group score = top2 sum of (gate + bias) + // ------------------------------------------------------------------ + { + int32_t const offset = warp_id * num_experts_per_group; + InT const* gate_g = gate_token + offset; + float const* bias_g = e_score_correction_bias + offset; + + float largest = neg_inf(); + float second_largest = neg_inf(); + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + float val = sigmoid_to_float(gate_g[i]) + bias_g[i]; + if (val > largest) { + second_largest = largest; + largest = val; + } else if (val > second_largest) { + second_largest = val; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) + largest = sigmoid_to_float(gate_g[i]) + bias_g[i]; + } + __syncwarp(); + float max1 = cg::reduce(tile, largest, cg::greater()); + float max2 = max1; + int cnt = __popc(__ballot_sync(FUSED_FULL_WARP_MASK, largest == max1)); + if (cnt == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + if (lane_id == 0) s_group_scores[warp_id] = max1 + max2; + } + + __syncthreads(); // __syncwarp() maybe better? + + // ------------------------------------------------------------------ + // Phase 2 (warp0 only): group selection → expert selection → output + // ------------------------------------------------------------------ + if (warp_id != 0) return; + + float value = neg_inf(); + float topk_group_value = neg_inf(); + int32_t num_equalto_topkth_group; + if (token_id < num_tokens) { + int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && (isfinite(s_group_scores[lane_id]))) { + value = s_group_scores[lane_id]; + } + + int neg_inf_num = WARP_SIZE - n_group; + int last_neg_inf_num = 0; + // Use loop to find the largset top_group + while (neg_inf_num < want_neg_inf_num) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = neg_inf(); + } + last_neg_inf_num = neg_inf_num; + + neg_inf_num = __popc( + __ballot_sync(FUSED_FULL_WARP_MASK, (value == neg_inf()))); + } + // There is a possible case: + // may have many different group holding the same score! + // but we only accept some of them! + num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num; + } + __syncwarp(); + + warp_topk_fused::WarpSelect + queue((int32_t)topk, neg_inf()); + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = (topk_group_value != neg_inf()); + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((s_group_scores[i_group] > topk_group_value) || + ((s_group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_epg; i += WARP_SIZE) { + float candidates = neg_inf(); + if (i < num_experts_per_group) { + float biased = sigmoid_to_float(gate_token[offset + i]) + + e_score_correction_bias[offset + i]; + if (is_finite_val(biased)) candidates = biased; + } + queue.add(candidates, offset + i); + } + if (s_group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + } + + float topk_sum = 1e-20; + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk_fused::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + int32_t idx = i / WARP_SIZE; + float value = + i < topk ? sigmoid_to_float(gate_token[queue.get_idx(idx)]) : 0.0f; + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += cg::reduce(tile, value, cg::plus()); + } + } + __syncwarp(); + + if (token_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores_token[i] = 0; + } + } + __syncwarp(); + + topk_values += (int64_t)token_id * topk; + topk_indices += (int64_t)token_id * topk; + if (token_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = s_topk_value[i] / topk_sum * routed_scaling_factor; + } else { + value = s_topk_value[i] * routed_scaling_factor; + } + int32_t idx = i / WARP_SIZE; // topk may be bigger than WARP_SIZE + scores_token[queue.get_idx(idx)] = value; + topk_indices[i] = queue.get_idx(idx); + topk_values[i] = value; + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + int32_t idx = i / WARP_SIZE; + topk_indices[i] = queue.get_idx(idx); + topk_values[i] = static_cast(1.0f / topk); + } + } + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// --------------------------------------------------------------------------- +// Launch wrapper +// --------------------------------------------------------------------------- +template +void invokeFusedNoAuxTc(InT* gating_output, + float* e_score_correction_bias, + float* scores, + float* topk_values, + IdxT* topk_indices, + int64_t const num_tokens, + int64_t const num_experts, + int64_t const n_group, + int64_t const topk_group, + int64_t const topk, + bool const renormalize, + double const routed_scaling_factor, + cudaStream_t const stream) { + auto* kernel = &grouped_topk_fused_kernel; + + // blockDim = n_group * WARP_SIZE (one warp per group) + int32_t const num_warps = static_cast(n_group); + + // smem = WarpSelect staging (float) + 16B pad + group_scores buffer (float) + size_t const val_aligned = warp_topk_fused::round_up_to_multiple_of<256>( + static_cast(num_warps) * WARP_SIZE * sizeof(float)); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(float); + size_t const smem_bytes = val_aligned + idx_bytes + extra_bytes; + + cudaLaunchConfig_t config; + config.gridDim = static_cast(num_tokens); + config.blockDim = static_cast(n_group) * WARP_SIZE; + config.dynamicSmemBytes = smem_bytes; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = false; + config.numAttrs = 1; + config.attrs = attrs; + + cudaLaunchKernelEx(&config, + kernel, + scores, + topk_values, + topk_indices, + gating_output, + e_score_correction_bias, + num_tokens, + num_experts, + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor); +} + +#define INSTANTIATE_FUSED_NOAUX_TC(InT, IdxT) \ + template void invokeFusedNoAuxTc( \ + InT * gating_output, \ + float* e_score_correction_bias, \ + float* scores, \ + float* topk_values, \ + IdxT* topk_indices, \ + int64_t const num_tokens, \ + int64_t const num_experts, \ + int64_t const n_group, \ + int64_t const topk_group, \ + int64_t const topk, \ + bool const renormalize, \ + double const routed_scaling_factor, \ + cudaStream_t const stream); + +INSTANTIATE_FUSED_NOAUX_TC(float, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__nv_bfloat16, int64_t); +INSTANTIATE_FUSED_NOAUX_TC(__half, int64_t); + +// --------------------------------------------------------------------------- +// Paddle op wrapper +// --------------------------------------------------------------------------- +std::vector grouped_topk( + paddle::Tensor& gating_output, + paddle::Tensor& e_score_correction_bias, + int n_group, + int topk_group, + int topk, + bool renormalize, + float routed_scaling_factor) { + auto input_shape = gating_output.shape(); + PD_CHECK(input_shape.size() == 2); + int64_t num_tokens = input_shape[0]; + int64_t num_experts = input_shape[1]; + auto place = gating_output.place(); + PD_CHECK(n_group <= 32, "grouped_topk: n_group must be <= 32"); + PD_CHECK(topk <= 32, "grouped_topk: topk must be <= WARP_SIZE (32)"); + + // Outputs are always float32 regardless of input dtype + auto scores = paddle::empty( + {num_tokens, num_experts}, paddle::DataType::FLOAT32, place); + auto topk_values = + paddle::empty({num_tokens, topk}, paddle::DataType::FLOAT32, place); + auto topk_indices = + paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place); + + auto stream = gating_output.stream(); + auto dtype = gating_output.dtype(); + + float* scores_ptr = reinterpret_cast(scores.data()); + float* topk_values_ptr = reinterpret_cast(topk_values.data()); + int64_t* topk_idx_ptr = + reinterpret_cast(topk_indices.data()); + float* bias_ptr = + reinterpret_cast(e_score_correction_bias.data()); + + if (dtype == paddle::DataType::BFLOAT16) { + invokeFusedNoAuxTc<__nv_bfloat16, int64_t>( + reinterpret_cast<__nv_bfloat16*>( + gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else if (dtype == paddle::DataType::FLOAT16) { + invokeFusedNoAuxTc<__half, int64_t>( + reinterpret_cast<__half*>(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } else { + PD_CHECK( + dtype == paddle::DataType::FLOAT32, + "grouped_topk: gating_output must be float32, float16, or bfloat16"); + invokeFusedNoAuxTc( + reinterpret_cast(gating_output.data()), + bias_ptr, + scores_ptr, + topk_values_ptr, + topk_idx_ptr, + num_tokens, + num_experts, + static_cast(n_group), + static_cast(topk_group), + static_cast(topk), + renormalize, + static_cast(routed_scaling_factor), + stream); + } + + return {scores, topk_values, topk_indices}; +} + +std::vector GroupedTopkInferDtype( + const paddle::DataType& /*gating_output_dtype*/, + const paddle::DataType& /*e_score_correction_bias_dtype*/) { + // Outputs are always float32: cast is fused into the kernel. + return {paddle::DataType::FLOAT32, + paddle::DataType::FLOAT32, + paddle::DataType::INT64}; +} + +std::vector> GroupedTopkInferShape( + const std::vector& gating_output_shape, + const std::vector&, + const int topk) { + auto num_tokens = gating_output_shape[0]; + auto num_experts = gating_output_shape[1]; + return {{num_tokens, num_experts}, {num_tokens, topk}, {num_tokens, topk}}; +} + +PD_BUILD_STATIC_OP(grouped_topk) + .Inputs({"gating_output", "e_score_correction_bias"}) + .Outputs({"output_tensor", "topk_values", "topk_indices"}) + .Attrs({"n_group: int", + "topk_group: int", + "topk: int", + "renormalize: bool", + "routed_scaling_factor: float"}) + .SetKernelFn(PD_KERNEL(grouped_topk)) + .SetInferShapeFn(PD_INFER_SHAPE(GroupedTopkInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GroupedTopkInferDtype)); diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 83f3ad1077d..cb8c2e3e623 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -73,6 +73,8 @@ namespace cub = hipcub; using json = nlohmann::json; #endif +#define CEILDIV(a, b) (((a + b - 1) / b)) + #define CUDA_CHECK(call) \ do { \ const cudaError_t error_code = call; \ diff --git a/custom_ops/gpu_ops/merge_prefill_decode_output.cu b/custom_ops/gpu_ops/merge_prefill_decode_output.cu index a57d72bdf3a..b158ca7f0e8 100644 --- a/custom_ops/gpu_ops/merge_prefill_decode_output.cu +++ b/custom_ops/gpu_ops/merge_prefill_decode_output.cu @@ -44,13 +44,49 @@ __global__ void FillEncoderDecoderResKernel(T *encoder_res_data, return; } - const int load_idx = - ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim + land_id * 4; + const int base_idx = + ((cu_seq_q[bidb] + token_id) * head_num + bidh) * head_dim; - *reinterpret_cast(encoder_res_data + load_idx) = - *reinterpret_cast(decoder_res_data + load_idx); + if (head_dim == 128) { + const int load_idx = base_idx + land_id * 4; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + } else if (head_dim == 192) { + const int load_idx = base_idx + land_id * 4; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + if (land_id < 16) { + *reinterpret_cast(encoder_res_data + load_idx + 128) = + *reinterpret_cast(decoder_res_data + load_idx + 128); + } + } else if (head_dim == 256) { + // float4 = 单条LDG.128,性能最优 + const int load_idx = base_idx + land_id * 8; + *reinterpret_cast(encoder_res_data + load_idx) = + *reinterpret_cast(decoder_res_data + load_idx); + } } +#define LAUNCH_KERNEL(T, WARPS) \ + FillEncoderDecoderResKernel \ + <<>>( \ + const_cast(encoder_res.data()), \ + const_cast(decoder_res.data()), \ + seq_lens_encoder.data(), \ + seq_lens_decoder.data(), \ + seq_lens_this_time.data(), \ + cu_seq_q.data(), \ + head_num, \ + head_dim) + +#define LAUNCH_KERNEL_BY_HEAD_DIM(T) \ + if (head_dim == 128) \ + LAUNCH_KERNEL(T, 4); \ + else if (head_dim == 192) \ + LAUNCH_KERNEL(T, 6); \ + else if (head_dim == 256) \ + LAUNCH_KERNEL(T, 8) + void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res, const paddle::Tensor &decoder_res, const paddle::Tensor &seq_lens_encoder, @@ -60,41 +96,20 @@ void MergePrefillDecodeOutput(const paddle::Tensor &encoder_res, const int head_num, const int head_dim, const int max_token) { - if (head_dim != 128) { - PD_THROW("Only supported head_dim = 128"); + if (head_dim != 128 && head_dim != 192 && head_dim != 256) { + PD_THROW("Only supported head_dim = 128, 192 or 256"); } const int batch_size = seq_lens_encoder.shape()[0]; - constexpr int warps = 4; + const int warps = head_dim / 32; const int tokens_block = (max_token + warps - 1) / warps; - dim3 grid_dims; - grid_dims.x = batch_size; - grid_dims.y = head_num; - grid_dims.z = tokens_block; + dim3 grid_dims(batch_size, head_num, tokens_block); if (encoder_res.dtype() == paddle::DataType::FLOAT16) { using T = phi::dtype::float16; - FillEncoderDecoderResKernel - <<>>( - const_cast(encoder_res.data()), - const_cast(decoder_res.data()), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seq_q.data(), - head_num, - head_dim); + LAUNCH_KERNEL_BY_HEAD_DIM(T); } else if (encoder_res.dtype() == paddle::DataType::BFLOAT16) { using T = phi::dtype::bfloat16; - FillEncoderDecoderResKernel - <<>>( - const_cast(encoder_res.data()), - const_cast(decoder_res.data()), - seq_lens_encoder.data(), - seq_lens_decoder.data(), - seq_lens_this_time.data(), - cu_seq_q.data(), - head_num, - head_dim); + LAUNCH_KERNEL_BY_HEAD_DIM(T); } } diff --git a/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu b/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu index 6eda3598cd3..4316aa5cbda 100644 --- a/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu +++ b/custom_ops/gpu_ops/moe/deepgemm_preprocess.cu @@ -15,10 +15,11 @@ #include "helper.h" #include "paddle/extension.h" -template +template __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids, int32_t *__restrict__ res, int32_t *__restrict__ res_padded, + int32_t *__restrict__ res_padded_cumsum, size_t numel, int num_experts) { extern __shared__ int32_t tokens_per_ep[]; @@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids, __syncthreads(); - for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) { - res[i] = tokens_per_ep[i]; - res_padded[i] = (res[i] + 127) / 128 * 128; + if constexpr (kComputeCumsum) { + if (threadIdx.x == 0) { + int32_t running_sum = 0; + for (int i = 0; i < num_experts; i++) { + int32_t count = tokens_per_ep[i]; + int32_t padded = (count + 127) / 128 * 128; + res[i] = count; + res_padded[i] = padded; + running_sum += padded; + res_padded_cumsum[i] = running_sum; + } + } + } else { + for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) { + res[i] = tokens_per_ep[i]; + res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128; + } } } std::vector count_tokens_per_expert_func( - const paddle::Tensor &topk_ids, int64_t num_experts) { + const paddle::Tensor &topk_ids, + int64_t num_experts, + bool compute_padded_cumsum) { int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1]; + int64_t num_rows = compute_padded_cumsum ? 3 : 2; auto token_nums_per_expert = paddle::empty( - {2, num_experts}, paddle::DataType::INT32, topk_ids.place()); + {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place()); auto stream = topk_ids.stream(); using scalar_t = int64_t; - // CUDA_CHECK(cudaGetLastError()); - cuda_kernel<<<1, 1024, num_experts * sizeof(int32_t), stream>>>( - topk_ids.data(), - token_nums_per_expert.data(), - token_nums_per_expert.data() + num_experts, - topk_ids_numel, - num_experts); + if (compute_padded_cumsum) { + cuda_kernel + <<<1, 1024, num_experts * sizeof(int32_t), stream>>>( + topk_ids.data(), + token_nums_per_expert.data(), + token_nums_per_expert.data() + num_experts, + token_nums_per_expert.data() + 2 * num_experts, + topk_ids_numel, + num_experts); + } else { + cuda_kernel + <<<1, 1024, num_experts * sizeof(int32_t), stream>>>( + topk_ids.data(), + token_nums_per_expert.data(), + token_nums_per_expert.data() + num_experts, + nullptr, + topk_ids_numel, + num_experts); + } - // CUDA_CHECK(cudaGetLastError()); return {token_nums_per_expert}; } std::vector count_tokens_per_expert_func_infer_dtype( - const paddle::DataType &topk_ids_dtype, int64_t num_experts) { + const paddle::DataType &topk_ids_dtype, + int64_t num_experts, + bool compute_padded_cumsum) { return {paddle::DataType::INT32}; } std::vector> count_tokens_per_expert_func_infer_shape( - const std::vector &topk_ids_shape, int64_t num_experts) { - return {{2, num_experts}}; + const std::vector &topk_ids_shape, + int64_t num_experts, + bool compute_padded_cumsum) { + int64_t num_rows = compute_padded_cumsum ? 3 : 2; + return {{num_rows, num_experts}}; } PD_BUILD_STATIC_OP(count_tokens_per_expert_func) .Inputs({"topk_ids"}) .Outputs({"token_nums_per_expert"}) - .Attrs({"num_experts:int64_t"}) + .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"}) .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func)) .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape)) .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype)); diff --git a/custom_ops/gpu_ops/moe/moe_align_kernel.cu b/custom_ops/gpu_ops/moe/moe_align_kernel.cu new file mode 100644 index 00000000000..4d2a01d8dd9 --- /dev/null +++ b/custom_ops/gpu_ops/moe/moe_align_kernel.cu @@ -0,0 +1,604 @@ + +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Reference +// https://raw.githubusercontent.com/sgl-project/sglang/refs/heads/main/sgl-kernel/csrc/moe/moe_align_kernel.cu +// Licensed under Apache License 2.0 +// with further performance optimizations applied. + +#include + +#include "helper.h" +#include "paddle/extension.h" + +#define VEC_SIZE 4 +using Vec = int4; + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ int warp_exclusive_scan( + int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} +#endif + +template +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size, + int32_t max_num_tokens_padded) { + // Use a separate thread block to populate sorted_token_ids + if (blockIdx.x == 1) { + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = threadIdx.x; i < total_vecs; i += blockDim.x) { + out_ptr[i] = fill_vec; + } + } + return; + } + + extern __shared__ int32_t smem[]; + int32_t* shared_counts = smem; // [num_experts] + int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] + int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + __shared__ int32_t s_total_tokens_post_pad; + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + atomicAdd(&shared_counts[expert_id], 1); + } + + __syncthreads(); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + +#ifndef __CUDA_ARCH__ // HIP + + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + + __syncthreads(); + + // Blelloch scan + int offset = 1; +#pragma unroll + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + __syncthreads(); + } + + // down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + __syncthreads(); + +#pragma unroll + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + __syncthreads(); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + +#else // CUDA + + // Intra warp prefix sum + int32_t* warp_sums = scan_buf + scan_size; // [<= 32] + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid & (WARP_SIZE - 1); + const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE; + const int warp_sum = warp_exclusive_scan(padded_count) + padded_count; + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum; + __syncthreads(); + + // warp0 accumulate all the block's prefix sum + if (tid < WARP_SIZE) { + int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0; + int incl = warp_exclusive_scan(val) + val; + warp_sums[tid] = incl; + } + __syncthreads(); + + // Every thread obtains the whole block's sum + if (tid == 0) { + prefix[num_experts] = warp_sums[num_warps_for_scan - 1]; + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + + // Fill 0 to scan_buf extended area (tid >= num_expert) + if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0; + __syncthreads(); + + // Perform 2 level exclusive-prefix-sum to scan_buf + int v = (tid < scan_size) ? scan_buf[tid] : 0; + int pre = warp_exclusive_scan(v); + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v; + __syncthreads(); + + if (warp_id == 0) { + int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0; + warp_sums[lane_id] = warp_exclusive_scan(val); + } + __syncthreads(); + + int offset = warp_sums[warp_id]; + if (tid < scan_size) scan_buf[tid] = pre + offset; + __syncthreads(); + + // Write prefix[0..num_experts - 1] and cumsum + if (tid < num_experts) prefix[tid] = scan_buf[tid]; +#endif + + if (tid <= num_experts) { + cumsum[tid] = prefix[tid]; + } + // fill expert_ids + const int32_t num_blocks = s_total_tokens_post_pad / block_size; + for (int32_t i = tid; i < num_blocks; i += stride) { + int32_t block_start = i * block_size; + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (prefix[mid] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 2; + } +} + +// ===== Cooperative fused kernel for large batch (single launch, grid.sync) + +namespace cg = cooperative_groups; + +template +__global__ void moe_align_block_size_cooperative_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t* __restrict__ global_counts, // [num_experts+1], zeroed by caller + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + cg::grid_group grid = cg::this_grid(); + + extern __shared__ int32_t smem[]; + // smem layout: [num_experts] local_hist + [num_experts+1] expert_starts + int32_t* local_hist = smem; + int32_t* expert_starts_local = smem + num_experts; + + const int bid = blockIdx.x; + const int tid = threadIdx.x; + const int nthreads = blockDim.x; + const int nblocks = gridDim.x; + + __shared__ int32_t s_total; + + // ===== Stage 0: Cooperative initialization ===== + // Fill sorted_token_ids with sentinel value (all blocks cooperate) + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = + static_cast(numel); + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = bid * nthreads + tid; i < total_vecs; + i += nblocks * nthreads) { + out_ptr[i] = fill_vec; + } + } + + // Initialize local histogram to 0 + for (int i = tid; i < num_experts; i += nthreads) { + local_hist[i] = 0; + } + __syncthreads(); + + // ===== Stage 1: Local histogram + global atomic merge ===== + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + atomicAdd(&local_hist[expert_id], 1); + } + __syncthreads(); + + // Merge local counts into global via atomic fetch-and-add. + // Return value = prefix_before (reuse local_hist to store it). + for (int i = tid; i < num_experts; i += nthreads) { + int32_t count = local_hist[i]; + int32_t prefix_before = atomicAdd(&global_counts[i], count); + local_hist[i] = prefix_before; + } + + grid.sync(); // all histograms merged, global_counts has totals + + // ===== Stage 2: Redundant prefix sum per block ===== + if (tid == 0) { + int32_t running_sum = 0; + for (int i = 0; i < num_experts; i++) { + int32_t count = global_counts[i]; + int32_t padded = (count + block_size - 1) / block_size * block_size; + expert_starts_local[i] = running_sum; + running_sum += padded; + } + expert_starts_local[num_experts] = running_sum; // total + s_total = running_sum; + } + + grid.sync(); + + // Block 0 writes total_tokens_post_pad and cumsum (global_counts) + if (bid == 0) { + // Write expert starts to global_counts for external consumers + if (tid <= num_experts) { + global_counts[tid] = expert_starts_local[tid]; + } + if (tid == 0) { + *total_tokens_post_pad = s_total; + } + } + + // ===== Stage 3: Fill expert_ids (all blocks cooperate) ===== + const int32_t num_blocks_out = s_total / block_size; + for (int32_t i = bid * nthreads + tid; i < num_blocks_out; + i += nblocks * nthreads) { + int32_t block_start = i * block_size; + // Binary search: find the expert whose start <= block_start < next start + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (expert_starts_local[mid + 1] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 1; // expert indexing: topk_ids uses +1 offset + } + + // ===== Stage 4: Scatter tokens using shared memory atomics ===== + // local_hist[i] currently holds prefix_before for this block. + // We do atomic_add on local_hist to get each token's rank within the expert, + // then add expert_starts_local to get the final position. + for (size_t i = (size_t)bid * nthreads + tid; i < numel; + i += (size_t)nblocks * nthreads) { + int expert_id = static_cast(topk_ids[i]) + 1; + int32_t rank = atomicAdd(&local_hist[expert_id], 1); + int32_t pos = rank + expert_starts_local[expert_id]; + sorted_token_ids[pos] = i; + } +} + +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + // Adapted from + // https://github.com/vllm-project/vllm/pull/29642/files#diff-5647b1413f4ae9aacba904eca8f8a8aee9079321eadff4c10101a2c6962dcc53R226 + // Use an additional group of threads to fill sorted_token_ids. + // Since the kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. + if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + if (pad_sorted_token_ids) { + for (int32_t it = threadIdx.x; it < max_num_tokens_padded; + it += fill_threads) { + sorted_token_ids[it] = numel; + } + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + ++tokens_cnts[(tid + 1) * num_experts + expert_id]; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += + tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid - 1; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = + tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } +} + +template +void moe_align_block_size(const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t block_size, + paddle::Tensor& sorted_token_ids, + paddle::Tensor& experts_ids, + paddle::Tensor& num_tokens_post_pad, + paddle::Tensor& cumsum_buffer, + bool pad_sorted_token_ids) { + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + auto stream = topk_ids.stream(); + + const size_t numel = topk_ids.numel(); + const int64_t max_num_tokens_padded = sorted_token_ids.shape()[0]; + + bool small_batch_expert_mode = (numel < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t expert_threads = max((int32_t)num_experts, WARP_SIZE); + constexpr int32_t fill_threads = 256; + const int32_t shared_mem_size = + ((expert_threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + auto small_batch_expert_kernel = + moe_align_block_size_small_batch_expert_kernel; + small_batch_expert_kernel<<<1, + fill_threads + expert_threads, + shared_mem_size, + stream>>>(topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + pad_sorted_token_ids, + max_num_tokens_padded); + } else { + // Use cooperative fused kernel for large inputs where multi-block + // parallelism outweighs cooperative launch overhead + if (numel >= 16384) { + const int coop_threads = 256; + const size_t coop_smem = (2 * num_experts + 1) * sizeof(int32_t); + + auto coop_kernel = moe_align_block_size_cooperative_kernel; + + static int cached_max_blocks_per_sm = 0; + static int cached_num_sms = 0; + if (cached_num_sms == 0) { + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&cached_max_blocks_per_sm, + (void*)coop_kernel, + coop_threads, + coop_smem); + int device_id; + cudaGetDevice(&device_id); + cudaDeviceGetAttribute( + &cached_num_sms, cudaDevAttrMultiProcessorCount, device_id); + } + + int max_coop_blocks = cached_max_blocks_per_sm * cached_num_sms; + int desired_blocks = std::max( + 1, std::min(256, static_cast(numel / (coop_threads * 4)))); + int coop_blocks = std::min(desired_blocks, max_coop_blocks); + if (coop_blocks < 1) coop_blocks = 1; + + const scalar_t* topk_ids_ptr = topk_ids.data(); + int32_t* sorted_token_ids_ptr = sorted_token_ids.data(); + int32_t* experts_ids_ptr = experts_ids.data(); + int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data(); + int32_t* cumsum_ptr = cumsum_buffer.data(); + int32_t num_experts_i32 = static_cast(num_experts); + int32_t block_size_i32 = static_cast(block_size); + size_t numel_val = numel; + bool pad_val = pad_sorted_token_ids; + int32_t max_padded_i32 = static_cast(max_num_tokens_padded); + + void* args[] = {&topk_ids_ptr, + &sorted_token_ids_ptr, + &experts_ids_ptr, + &num_tokens_post_pad_ptr, + &cumsum_ptr, + &num_experts_i32, + &block_size_i32, + &numel_val, + &pad_val, + &max_padded_i32}; + + cudaError_t err = cudaLaunchCooperativeKernel((void*)coop_kernel, + dim3(coop_blocks), + dim3(coop_threads), + args, + coop_smem, + stream); + + if (err == cudaSuccess) { + return; + } + // Fall through to original path if cooperative launch failed + } + + // Original 2-kernel approach (for medium inputs or cooperative fallback) + auto align_kernel = moe_align_block_size_kernel; + + const size_t scan_size = next_pow_2(num_experts); + const size_t shared_mem_size = + (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * + sizeof(int32_t); + align_kernel<<<2, threads, shared_mem_size, stream>>>( + topk_ids.data(), + sorted_token_ids.data(), + experts_ids.data(), + num_tokens_post_pad.data(), + num_experts, + block_size, + numel, + cumsum_buffer.data(), + pad_sorted_token_ids, + scan_size, + max_num_tokens_padded); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = ((int)numel + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( + topk_ids.data(), + sorted_token_ids.data(), + cumsum_buffer.data(), + numel); + } +} + +// Explicit instantiations for use from other translation units (e.g. +// tritonmoe_preprocess.cu) +template void moe_align_block_size(const paddle::Tensor&, + int64_t, + int64_t, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + bool); +template void moe_align_block_size(const paddle::Tensor&, + int64_t, + int64_t, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + paddle::Tensor&, + bool); diff --git a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu index 071e0a9b418..eb680ea744e 100644 --- a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu +++ b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu @@ -15,83 +15,40 @@ #include "helper.h" #include "paddle/extension.h" -#define CEILDIV(a, b) (((a + b - 1) / b)) - template -__global__ void count_and_sort_expert_tokens_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, - size_t numel) { - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = blockDim.x * gridDim.x; - - for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i]; - int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); - sorted_token_ids[rank_post_pad] = i; - } -} - -template -__global__ void moe_align_block_size_kernel( - const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, - int32_t GEMM_BLOCK_SIZE_M, - size_t numel, - int32_t* __restrict__ cumsum_buffer) { - __shared__ int32_t tokens_per_ep[num_experts]; - - for (int i = threadIdx.x; i < num_experts; i += blockDim.x) { - tokens_per_ep[i] = 0; - } - - __syncthreads(); - - for (int i = threadIdx.x; i < numel; i += blockDim.x) { - int expert_id = topk_ids[i]; - atomicAdd(&tokens_per_ep[expert_id], 1); - } - - __syncthreads(); - - if (threadIdx.x == 0) { - cumsum_buffer[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = tokens_per_ep[i - 1]; - cumsum_buffer[i] = - cumsum_buffer[i - 1] + - CEILDIV(expert_count, GEMM_BLOCK_SIZE_M) * GEMM_BLOCK_SIZE_M; - } - *total_tokens_post_pad = cumsum_buffer[num_experts]; - } - - __syncthreads(); - - if (threadIdx.x < num_experts) { - for (int i = cumsum_buffer[threadIdx.x]; i < cumsum_buffer[threadIdx.x + 1]; - i += GEMM_BLOCK_SIZE_M) { - expert_ids[i / GEMM_BLOCK_SIZE_M] = threadIdx.x; - } - } -} +void moe_align_block_size(const paddle::Tensor& topk_ids, + int64_t num_experts, + int64_t block_size, + paddle::Tensor& sorted_token_ids, + paddle::Tensor& experts_ids, + paddle::Tensor& num_tokens_post_pad, + paddle::Tensor& cumsum_buffer, + bool pad_sorted_token_ids); std::vector> tritonmoe_preprocessInferShape( const std::vector& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) { - int topk_ids_numel = topk_ids[0] * topk_ids[1]; - int max_num_tokens_padded = - topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1); + int topk_ids_numel = 1; + for (int64_t dim : topk_ids) { + topk_ids_numel *= static_cast(dim); + } + int max_num_tokens_padded; + if (topk_ids_numel < num_experts + 1) { + max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M; + } else { + max_num_tokens_padded = + topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1); + } std::vector sorted_ids = {max_num_tokens_padded}; - int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M; - std::vector expert_ids = {max_num_m_blocks}; + int max_num_m_blocks = + (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M; + std::vector experts_ids = {max_num_m_blocks}; std::vector num_tokens_post_pad = {1}; - return {sorted_ids, expert_ids, num_tokens_post_pad}; + return {sorted_ids, experts_ids, num_tokens_post_pad}; } std::vector tritonmoe_preprocessIferDtype( @@ -127,76 +84,50 @@ std::vector tritonmoe_preprocess_kernel( const paddle::Tensor& topk_ids, int64_t num_experts, int64_t GEMM_BLOCK_SIZE_M) { - int topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1]; - int max_num_tokens_padded = - topk_ids_numel + num_experts * (GEMM_BLOCK_SIZE_M - 1); + int topk_ids_numel = static_cast(topk_ids.numel()); + + int max_num_tokens_padded; + if (topk_ids_numel < num_experts + 1) { + max_num_tokens_padded = topk_ids_numel * GEMM_BLOCK_SIZE_M; + } else { + max_num_tokens_padded = + topk_ids_numel + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1); + } auto sorted_ids = paddle::full({max_num_tokens_padded}, topk_ids_numel, paddle::DataType::INT32, topk_ids.place()); - int max_num_m_blocks = max_num_tokens_padded / GEMM_BLOCK_SIZE_M; + int max_num_m_blocks = + (max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) / GEMM_BLOCK_SIZE_M; - auto expert_ids = paddle::empty( + auto experts_ids = paddle::empty( {max_num_m_blocks}, paddle::DataType::INT32, topk_ids.place()); auto num_tokens_post_pad = paddle::empty({1}, paddle::DataType::INT32, topk_ids.place()); - auto cumsum_buffer = paddle::empty( - {num_experts + 1}, paddle::DataType::INT32, topk_ids.place()); + auto cumsum_buffer = paddle::zeros( + {num_experts + 2}, paddle::DataType::INT32, topk_ids.place()); - auto stream = topk_ids.stream(); using scalar_t = int64_t; - -#define run_align_kernel(num_experts) \ - auto align_kernel = moe_align_block_size_kernel; \ - align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data(), \ - expert_ids.data(), \ - num_tokens_post_pad.data(), \ - GEMM_BLOCK_SIZE_M, \ - topk_ids_numel, \ - cumsum_buffer.data()); - - if (num_experts == 8) { - run_align_kernel(8); - } else if (num_experts == 256) { - run_align_kernel(256); - } else if (num_experts == 2) { - run_align_kernel(2); - } else if (num_experts == 64) { - run_align_kernel(64); - } else if (num_experts == 128) { - run_align_kernel(128); - } else if (num_experts == 160) { - run_align_kernel(160); - } else if (num_experts == 32) { - run_align_kernel(32); - } else { - PD_THROW("Not support num_experts: %d", num_experts); - } - - const int block_threads = 256; - const int num_blocks = CEILDIV(topk_ids_numel, block_threads); - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - - auto sort_kernel = count_and_sort_expert_tokens_kernel; - - sort_kernel<<>>( - topk_ids.data(), - sorted_ids.data(), - cumsum_buffer.data(), - topk_ids_numel); - - return {sorted_ids, expert_ids, num_tokens_post_pad}; + moe_align_block_size(topk_ids, + num_experts + 1, + GEMM_BLOCK_SIZE_M, + sorted_ids, + experts_ids, + num_tokens_post_pad, + cumsum_buffer, + true); + + return {sorted_ids, experts_ids, num_tokens_post_pad}; } PD_BUILD_STATIC_OP(tritonmoe_preprocess) .Inputs({"topk_ids"}) .Attrs({"num_experts: int64_t", "GEMM_BLOCK_SIZE_M: int64_t"}) - .Outputs({"sorted_ids", "expert_ids", "num_tokens_post_pad"}) + .Outputs({"sorted_ids", "experts_ids", "num_tokens_post_pad"}) .SetKernelFn(PD_KERNEL(tritonmoe_preprocess_kernel)) .SetInferShapeFn(PD_INFER_SHAPE(tritonmoe_preprocessInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(tritonmoe_preprocessIferDtype)); diff --git a/custom_ops/gpu_ops/save_output_msg_with_topk.cc b/custom_ops/gpu_ops/save_output_msg_with_topk.cc index 0a7d2ab6eac..3069cb3929b 100644 --- a/custom_ops/gpu_ops/save_output_msg_with_topk.cc +++ b/custom_ops/gpu_ops/save_output_msg_with_topk.cc @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, : -inference_msg_id_from_env; int bsz = x.shape()[0]; int max_num_logprobs = logprob_token_ids.shape()[1]; - msg_sed.mtext[1] = bsz; + // Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1]. + // token_processor unpacks both fields to avoid reading unused topk slots. + msg_sed.mtext[1] = bsz | (max_num_logprobs << 16); for (int i = 0; i < bsz; i++) { - for (int j = 0; j < K + 1; j++) { - const int64_t offset = i * (K + 1) + j; + // Loop only over actual logprob columns (max_num_logprobs) instead of the + // fixed K+1=21, and use max_num_logprobs as the stride so data is packed + // densely in the message buffer. + for (int j = 0; j < max_num_logprobs; j++) { + const int64_t offset = i * max_num_logprobs + j; if (j == 0) { msg_sed.mtext[offset + 2] = (int)x_data[i]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; - } else if (j < max_num_logprobs) { - msg_sed.mtext[offset + 2] = - (int)logprob_token_ids_data[i * max_num_logprobs + j]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } else { - msg_sed.mtext[offset + 2] = -1; - msg_sed.mtext_f[offset] = 0.0; + msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } if (preempted_idx_data[i] == 1) { msg_sed.mtext[offset + 2] = -9; diff --git a/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu new file mode 100644 index 00000000000..790ba551485 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/build_sampling_params_logprob.cu @@ -0,0 +1,129 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +__global__ void BuildSamplingParamLogProbKernel( + T* output_params, + const T* input_params, + const int32_t* token_num_per_batch, + const int64_t token_num_output_cpu) { + const int bi = blockIdx.x; + const int tid = threadIdx.x; + + // Compute start offset: sum of token_num_per_batch[0..bi-1] + int start_offset = 0; + for (int i = 0; i < bi; i++) { + start_offset += token_num_per_batch[i]; + } + + int cur_token_num = token_num_per_batch[bi]; + + if (cur_token_num <= 0) { + return; + } + + // Read per-batch param into register + T val = input_params[bi]; + + // Fill output_params with bounds check against total output size + for (int i = tid; i < cur_token_num; i += blockDim.x) { + int64_t idx = static_cast(start_offset) + i; + if (idx < token_num_output_cpu) { + output_params[idx] = val; + } + } +} + +std::vector BuildSamplingParamLogProb( + const paddle::Tensor& input_params, + const paddle::Tensor& token_num_per_batch, + const int64_t token_num_output_cpu) { + auto cu_stream = input_params.stream(); + // Initialize output to safe defaults for use as divisors: + // int32/float32 -> 1, bool -> false + paddle::Tensor output_params; + switch (input_params.dtype()) { + case paddle::DataType::BOOL: + output_params = paddle::full({token_num_output_cpu}, + false, + input_params.dtype(), + input_params.place()); + break; + case paddle::DataType::INT32: + output_params = paddle::full({token_num_output_cpu}, + 1, + input_params.dtype(), + input_params.place()); + break; + case paddle::DataType::FLOAT32: + output_params = paddle::full({token_num_output_cpu}, + 1.0f, + input_params.dtype(), + input_params.place()); + break; + default: + PD_THROW( + "Unsupported data type for BuildSamplingParamLogProb. " + "Only bool, int32, float32 are supported."); + } + + int32_t num_blocks = token_num_per_batch.shape()[0]; + switch (input_params.dtype()) { + case paddle::DataType::BOOL: { + BuildSamplingParamLogProbKernel<<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + case paddle::DataType::INT32: { + BuildSamplingParamLogProbKernel + <<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + case paddle::DataType::FLOAT32: { + BuildSamplingParamLogProbKernel<<>>( + output_params.data(), + input_params.data(), + token_num_per_batch.data(), + token_num_output_cpu); + break; + } + default: { + PD_THROW( + "Unsupported data type for BuildSamplingParamLogProb. " + "Only bool, int32, float32 are supported."); + } + } + + return {output_params}; +} + +PD_BUILD_STATIC_OP(build_sampling_params_logprob) + .Inputs({"input_params", "token_num_per_batch"}) + .Outputs({"output_params"}) + .Attrs({"token_num_output_cpu: int64_t"}) + .SetKernelFn(PD_KERNEL(BuildSamplingParamLogProb)); diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc new file mode 100644 index 00000000000..0ec49b854ae --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc @@ -0,0 +1,216 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" +#include "../../custom_ftok.h" +#include "../speculate_logprob_msg.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +void MTPSaveFirstTokenWithTopK(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& logprob_token_ids, + const paddle::Tensor& logprob_scores, + const paddle::Tensor& logprob_ranks, + const paddle::Tensor& token_num_per_batch, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& not_need_stop, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& preempted_idx, + int message_flag, // Target: 3, Draft: 4 + int64_t rank_id, + bool save_each_rank) { + if (!save_each_rank && rank_id > 0) { + return; + } + + int max_draft_tokens = sampled_token_ids.shape()[1]; + int bsz = token_num_per_batch.shape()[0]; + + auto sampled_token_ids_cpu = + sampled_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_token_ids_cpu = + logprob_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false); + auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false); + auto token_num_per_batch_cpu = + token_num_per_batch.copy_to(paddle::CPUPlace(), false); + auto cu_batch_token_offset_cpu = + cu_batch_token_offset.copy_to(paddle::CPUPlace(), false); + auto seq_lens_decoder_cpu = + seq_lens_decoder.copy_to(paddle::CPUPlace(), true); + auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true); + int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data(); + int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data(); + float* logprob_scores_data = logprob_scores_cpu.data(); + int64_t* logprob_ranks_data = logprob_ranks_cpu.data(); + int* token_num_per_batch_data = token_num_per_batch_cpu.data(); + int* cu_batch_token_offset_data = cu_batch_token_offset_cpu.data(); + int* seq_lens_decoder_data = seq_lens_decoder_cpu.data(); + int64_t* prompt_lens_data = prompt_lens_cpu.data(); + const int32_t* preempted_idx_data = preempted_idx.data(); + + static struct msgdata msg_sed; + int msg_queue_id = 1; + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 is perserve for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to got INFERENCE_MSG_ID at env, use (int)1 as default." + << std::endl; +#endif + } + static key_t key = custom_ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output_key: " << key << std::endl; + std::cout << "save msgid: " << msgid << std::endl; +#endif + msg_sed.mtype = 1; + msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env + : -inference_msg_id_from_env; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. + int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; + for (int i = 0; i < bsz; i++) { + int cur_token_num; + if (seq_lens_decoder_data[i] < prompt_lens_data[i] || + token_num_per_batch_data[i] == 0) { + // chunk prefill or stop slots + cur_token_num = 0; + } else { + cur_token_num = token_num_per_batch_data[i] + 1; + } + msg_sed.meta[3 + i] = cur_token_num; + if (preempted_idx_data[i] == 1) { + msg_sed.meta[3 + i] = -9; + } + + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + int token_offset = cu_batch_token_offset_data[i]; + for (int j = 0; j < cur_token_num; j++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + if (j == 0) { + // first token has full logprobs + for (int k = 0; k < max_num_logprobs; k++) { + if (k == 0) { + cur_tokens[k] = + (int)sampled_token_ids_data[i * max_draft_tokens + j]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; + } else { + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; + } + } + cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; + } else { + // draft token only has token_id + cur_tokens[0] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; + } + } + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << msg_sed.meta[0] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) + << ", bsz: " << msg_sed.meta[2] << std::endl; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_sed.meta[3 + i]; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + std::cout << "tokens: "; + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << cur_tokens[k] << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << cur_scores[k] << " "; + } + std::cout << std::endl; + std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl; + } + } + std::cout << std::endl; +#endif + if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) { + printf("full msg buffer\n"); + } +} + +PD_BUILD_STATIC_OP(mtp_save_first_token_with_topk) + .Inputs({"sampled_token_ids", + "logprob_token_ids", + "logprob_scores", + "logprob_ranks", + "token_num_per_batch", + "cu_batch_token_offset", + "not_need_stop", + "seq_lens_decoder", + "prompt_lens", + "preempted_idx"}) + .Attrs({"message_flag: int", "rank_id: int64_t", "save_each_rank: bool"}) + .SetKernelFn(PD_KERNEL(MTPSaveFirstTokenWithTopK)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc index 76ff5e190d8..3e5ed2430b0 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -19,27 +19,12 @@ #include #include "paddle/extension.h" #include "../custom_ftok.h" +#include "speculate_logprob_msg.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -#define MAX_BSZ 512 -#define K 20 -#define MAX_DRAFT_TOKEN_NUM 6 - -struct batch_msgdata { - int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - int ranks[MAX_DRAFT_TOKEN_NUM]; -}; - -struct msgdata { - long mtype; - int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums - batch_msgdata mtext[MAX_BSZ]; -}; - void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, const paddle::Tensor& output_scores, const paddle::Tensor& output_ranks, @@ -90,25 +75,28 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, int bsz = msg_rcv.meta[2]; output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + // Unpack message_flag (low 8 bits) and actual_topk (high 24 bits) from + // meta[1]. Keep packed value; Python unpacks message_flag and actual_topk. output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + int actual_topk = msg_rcv.meta[1] >> 8; - int output_tokens_offset = 3 + MAX_BSZ; + int output_tokens_offset = 3 + SPEC_LOGPROB_MAX_BSZ; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_rcv.meta[3 + i]; output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums auto* cur_output_token = output_tokens_data + output_tokens_offset + - i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_output_score = - output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)); auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; for (int j = 0; j < cur_token_num; j++) { - for (int k = 0; k < real_k + 1; k++) { - cur_output_token[j * (K + 1) + k] = - (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k]; - cur_output_score[j * (K + 1) + k] = - cur_batch_msg_rcv->scores[j * (K + 1) + k]; + for (int k = 0; k < actual_topk; k++) { + cur_output_token[j * (SPEC_LOGPROB_K + 1) + k] = + (int64_t)cur_batch_msg_rcv->tokens[j * (SPEC_LOGPROB_K + 1) + k]; + cur_output_score[j * (SPEC_LOGPROB_K + 1) + k] = + cur_batch_msg_rcv->scores[j * (SPEC_LOGPROB_K + 1) + k]; } output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = (int64_t)cur_batch_msg_rcv->ranks[j]; @@ -117,24 +105,27 @@ void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, #ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << output_tokens_data[0] - << ", message_flag: " << output_tokens_data[1] + << ", message_flag: " << (output_tokens_data[1] & 0xFF) + << ", max_num_logprobs: " << (output_tokens_data[1] >> 8) << ", bsz: " << output_tokens_data[2] << std::endl; for (int i = 0; i < output_tokens_data[2]; i++) { int cur_token_num = output_tokens_data[3 + i]; std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; for (int j = 0; j < cur_token_num; j++) { std::cout << "tokens: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << output_tokens_data[output_tokens_offset + - i * MAX_DRAFT_TOKEN_NUM * (K + 1) + - j * (K + 1) + k] + i * MAX_DRAFT_TOKEN_NUM * + (SPEC_LOGPROB_K + 1) + + j * (SPEC_LOGPROB_K + 1) + k] << " "; } std::cout << std::endl; std::cout << "scores: "; - for (int k = 0; k < K + 1; k++) { - std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) + - j * (K + 1) + k] + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { + std::cout << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * + (SPEC_LOGPROB_K + 1) + + j * (SPEC_LOGPROB_K + 1) + k] << " "; } std::cout << std::endl; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu index ca5d8353c3e..022c39bfb64 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu @@ -16,12 +16,12 @@ template __global__ inline void min_length_logits_process( - T *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + T* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -46,12 +46,12 @@ __global__ inline void min_length_logits_process( template <> __global__ inline void min_length_logits_process( - half *logits, - const int64_t *cur_len, - const int64_t *min_len, - const int64_t *eos_token_id, - const int *batch_id_per_token_output, - const int *cu_seqlens_q_output, + half* logits, + const int64_t* cur_len, + const int64_t* min_len, + const int64_t* eos_token_id, + const int* batch_id_per_token_output, + const int* cu_seqlens_q_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -74,11 +74,11 @@ __global__ inline void min_length_logits_process( } } -__global__ void update_repeat_times(const int64_t *token_ids_all, - const int64_t *prompt_lens, - const int64_t *cur_len, - int *repeat_times, - const int *batch_id_per_token_output, +__global__ void update_repeat_times(const int64_t* token_ids_all, + const int64_t* prompt_lens, + const int64_t* cur_len, + int* repeat_times, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -93,9 +93,9 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, return; } int tid = threadIdx.x; - const int64_t *pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; - int *repeat_times_now = repeat_times + token_idx * length; - for (int i = tid; i < length_id; i += blockDim.x) { + const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi]; + int* repeat_times_now = repeat_times + token_idx * length; + for (int i = tid; i < cur_len[bi]; i += blockDim.x) { int64_t id = pre_ids_now[i]; if (id < 0) break; atomicAdd(&repeat_times_now[id], 1); @@ -104,13 +104,13 @@ __global__ void update_repeat_times(const int64_t *token_ids_all, template __global__ void update_value_by_repeat_times( - const int *repeat_times, - const T *penalty_scores, - const T *frequency_score, - const T *presence_score, - const float *temperatures, - T *logits, - const int *batch_id_per_token_output, + const int* repeat_times, + const T* penalty_scores, + const T* frequency_score, + const T* presence_score, + const float* temperatures, + T* logits, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -121,8 +121,8 @@ __global__ void update_value_by_repeat_times( if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int *repeat_times_now = repeat_times + token_idx * length; + T* logits_now = logits + token_idx * length; + const int* repeat_times_now = repeat_times + token_idx * length; float alpha = static_cast(penalty_scores[bi]); float beta = static_cast(frequency_score[bi]); float gamma = static_cast(presence_score[bi]); @@ -138,10 +138,10 @@ __global__ void update_value_by_repeat_times( } template -__global__ void ban_bad_words(T *logits, - const int64_t *bad_tokens, - const int64_t *bad_tokens_len, - const int *batch_id_per_token_output, +__global__ void ban_bad_words(T* logits, + const int64_t* bad_tokens, + const int64_t* bad_tokens_len, + const int* batch_id_per_token_output, const int64_t token_num, const int64_t bs, const int64_t length, @@ -153,8 +153,8 @@ __global__ void ban_bad_words(T *logits, if (bi < 0) return; if (bi >= bs) return; int tid = threadIdx.x; - T *logits_now = logits + token_idx * length; - const int64_t *bad_tokens_now = bad_tokens + bi * bad_words_length; + T* logits_now = logits + token_idx * length; + const int64_t* bad_tokens_now = bad_tokens + bi * bad_words_length; const int32_t bad_token_len = static_cast(min(bad_tokens_len[bi], bad_words_length)); for (int i = tid; i < bad_token_len; i += blockDim.x) { @@ -166,21 +166,21 @@ __global__ void ban_bad_words(T *logits, template void token_penalty_multi_scores_kernel( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_score, - const paddle::Tensor &presence_score, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_score, + const paddle::Tensor& presence_score, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; @@ -198,8 +198,7 @@ void token_penalty_multi_scores_kernel( int64_t end_length = eos_token_id.shape()[0]; int block_size = (token_num + 32 - 1) / 32 * 32; min_length_logits_process<<<1, block_size, 0, cu_stream>>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), cur_len.data(), min_len.data(), eos_token_id.data(), @@ -230,15 +229,15 @@ void token_penalty_multi_scores_kernel( update_value_by_repeat_times <<>>( repeat_times.data(), - reinterpret_cast( - const_cast(penalty_scores.data())), - reinterpret_cast( - const_cast(frequency_score.data())), - reinterpret_cast( - const_cast(presence_score.data())), + reinterpret_cast( + const_cast(penalty_scores.data())), + reinterpret_cast( + const_cast(frequency_score.data())), + reinterpret_cast( + const_cast(presence_score.data())), temperatures.data(), - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast( + const_cast(logits.data())), batch_id_per_token_output.data(), token_num, bs, @@ -247,8 +246,7 @@ void token_penalty_multi_scores_kernel( block_size = (length_bad_words + 32 - 1) / 32 * 32; block_size = min(block_size, 512); ban_bad_words<<>>( - reinterpret_cast( - const_cast(logits.data())), + reinterpret_cast(const_cast(logits.data())), bad_tokens.data(), bad_tokens_len.data(), batch_id_per_token_output.data(), @@ -260,21 +258,21 @@ void token_penalty_multi_scores_kernel( } void SpecTokenPenaltyMultiScores( - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &logits, - const paddle::Tensor &penalty_scores, - const paddle::Tensor &frequency_scores, - const paddle::Tensor &presence_scores, - const paddle::Tensor &temperatures, - const paddle::Tensor &bad_tokens, - const paddle::Tensor &bad_tokens_len, - const paddle::Tensor &cur_len, - const paddle::Tensor &min_len, - const paddle::Tensor &eos_token_id, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &batch_id_per_token_output, - const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor& token_ids_all, + const paddle::Tensor& prompt_lens, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& bad_tokens_len, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token_output, + const paddle::Tensor& cu_seqlens_q_output, const int max_seq_len) { switch (logits.type()) { case paddle::DataType::BFLOAT16: { diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu index 18aa5d53d21..e620e914a25 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu @@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int64_t* next_tokens, // [bs, tokens_per_step] const int* max_think_lens, // [bs] int* max_reply_lens, // [bs] - int64_t* step_idx, // [bs] + const int64_t* step_idx, // [bs] const int64_t* eos_token_ids, // [eos_len] int* limit_status, // [bs] int* accept_num, // [bs] @@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int new_accept_num = original_accept_num; // 本 step 的 token offset 对应的绝对 step - const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + const int64_t current_base_step = step_idx[bid] + 1; for (int token_offset = 0; token_offset < original_accept_num; token_offset++) { @@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel( // inject_token_ids[0]) if (status == 0 && (current_step - 1) == - max_think_len) { // current_step - 1 是因为 speculate_verify 里 - // step_idx + 1 了 + max_think_len) { // current_step - 1 : 已输出 current_step-1 + // 个thinking token status = (inject_len > 0) ? 1 : done_status; } } else if (max_think_len == 0) { @@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel( } } - // 更新 step_idx / accept_num(被截断的 token 需要回退 - // step_idx) - const int discarded_tokens = original_accept_num - new_accept_num; - if (discarded_tokens > 0) { - step_idx[bid] -= discarded_tokens; - } - accept_num[bid] = new_accept_num; limit_status[bid] = status; max_reply_lens[bid] = max_reply_len; @@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength( const_cast(next_tokens.data()), max_think_lens.data(), const_cast(max_reply_lens.data()), - const_cast(step_idx.data()), + step_idx.data(), eos_token_ids.data(), const_cast(limit_status.data()), const_cast(accept_num.data()), diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h new file mode 100644 index 00000000000..dc2c6f399f4 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h @@ -0,0 +1,39 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#define SPEC_LOGPROB_MAX_BSZ 512 +#define SPEC_LOGPROB_K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (SPEC_LOGPROB_K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + // stop_flag, message_flag, bsz, batch_token_nums + int meta[3 + SPEC_LOGPROB_MAX_BSZ]; + batch_msgdata mtext[SPEC_LOGPROB_MAX_BSZ]; +}; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu index 76a84f30d4f..eadcf015f70 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -184,24 +184,65 @@ void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, real_bsz); } +template +__global__ void compute_cu_batch_offset_kernel(int* cu_batch_token_offset, + const int* accept_num, + const int real_bsz) { + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + int tid = threadIdx.x; + if (tid == 0) cu_batch_token_offset[0] = 0; + + int thread_data[ITEMS_PER_THREAD]; + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int batch_id = tid * ITEMS_PER_THREAD + i; + thread_data[i] = + batch_id < real_bsz ? accept_num[tid * ITEMS_PER_THREAD + i] : 0; + } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + int batch_id = tid * ITEMS_PER_THREAD + i; + if (batch_id < real_bsz) { + cu_batch_token_offset[batch_id + 1] = thread_data[i]; + } + } +} + template -__global__ void speculate_get_target_logits_kernel( +__global__ void speculate_get_accept_tokens_and_logits_kernel( + int64_t* token_ids, float* target_logits, const float* logits, const int* cu_batch_token_offset, - const int* ori_cu_batch_token_offset, + const int* cu_seqlens_q_output, const int* seq_lens_this_time, const int* seq_lens_encoder, const int* accept_num, + const int64_t* accept_tokens, const int vocab_size, + const int max_draft_tokens, const int real_bsz) { AlignedVector src_vec; const int bid = blockIdx.x; const int tid = threadIdx.x; if (bid < real_bsz) { + // get token_ids + if (tid == 0) { + auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens; + for (int i = 0; i < accept_num[bid]; i++) { + token_ids[cu_batch_token_offset[bid] + i] = accept_tokens_now[i]; + } + } + + // get output_logits auto* target_logits_now = target_logits + cu_batch_token_offset[bid] * vocab_size; - auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + cu_seqlens_q_output[bid] * vocab_size; for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { if (seq_lens_encoder[bid] > 0) { Load(&logits_now[i], &src_vec); @@ -217,31 +258,64 @@ __global__ void speculate_get_target_logits_kernel( } } -void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, - const paddle::Tensor& logits, - const paddle::Tensor& cu_batch_token_offset, - const paddle::Tensor& ori_cu_batch_token_offset, - const paddle::Tensor& seq_lens_this_time, - const paddle::Tensor& seq_lens_encoder, - const paddle::Tensor& accept_num) { +void SpeculateGetAcceptTokensAndLogits( + const paddle::Tensor& token_ids, + const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num, + const paddle::Tensor& accept_tokens) { auto cu_stream = seq_lens_this_time.stream(); const int vocab_size = logits.shape()[1]; - const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_occupied_slots = seq_lens_this_time.shape()[0]; + const int max_draft_tokens = accept_tokens.shape()[1]; + + const int BLOCK_DIM = 512; + PADDLE_ENFORCE_LE(max_occupied_slots, + 2048, + phi::errors::InvalidArgument( + "Only support bsz <= 2048, but received bsz is ", + max_occupied_slots)); + if (max_occupied_slots <= 512) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } else if (max_occupied_slots <= 1024) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } else if (max_occupied_slots <= 2048) { + compute_cu_batch_offset_kernel + <<<1, BLOCK_DIM, 0, cu_stream>>>( + const_cast(cu_batch_token_offset.data()), + accept_num.data(), + max_occupied_slots); + } constexpr int PackSize = VEC_16B / sizeof(float); - dim3 grid_dim(real_bsz); + dim3 grid_dim(max_occupied_slots); dim3 block_dim(128); - speculate_get_target_logits_kernel + speculate_get_accept_tokens_and_logits_kernel <<>>( + const_cast(token_ids.data()), const_cast(target_logits.data()), logits.data(), cu_batch_token_offset.data(), - ori_cu_batch_token_offset.data(), + cu_seqlens_q_output.data(), seq_lens_this_time.data(), seq_lens_encoder.data(), accept_num.data(), + accept_tokens.data(), vocab_size, - real_bsz); + max_draft_tokens, + max_occupied_slots); } PD_BUILD_STATIC_OP(speculate_get_logits) @@ -274,14 +348,20 @@ PD_BUILD_STATIC_OP(speculate_insert_first_token) .SetInplaceMap({{"token_ids", "token_ids_out"}}) .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken)); -PD_BUILD_STATIC_OP(speculate_get_target_logits) - .Inputs({"target_logits", +PD_BUILD_STATIC_OP(speculate_get_accept_tokens_and_logits) + .Inputs({"token_ids", + "target_logits", "logits", "cu_batch_token_offset", - "ori_cu_batch_token_offset", + "cu_seqlens_q_output", "seq_lens_this_time", "seq_lens_encoder", - "accept_num"}) - .Outputs({"target_logits_out"}) - .SetInplaceMap({{"target_logits", "target_logits_out"}}) - .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits)); + "accept_num", + "accept_tokens"}) + .Outputs({"token_ids_out", + "target_logits_out", + "cu_batch_token_offset_out"}) + .SetInplaceMap({{"token_ids", "token_ids_out"}, + {"target_logits", "target_logits_out"}, + {"cu_batch_token_offset", "cu_batch_token_offset_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetAcceptTokensAndLogits)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc index 2a040a7e7b4..f72f3774107 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc @@ -36,8 +36,9 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens, int msg_queue_id, int save_each_rank, bool skip_prefill) { - // printf("enter save output"); - if (!save_each_rank && rank_id > 0) { + // NOTE(yaohuicong): Skip non-zero TP ranks — they share identical sampling + // outputs, so only rank 0 needs to send results to the message queue. + if (rank_id > 0) { return; } diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc index 53e822e6223..a11897b7ff3 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -19,27 +19,12 @@ #include #include "paddle/extension.h" #include "../custom_ftok.h" +#include "speculate_logprob_msg.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -#define MAX_BSZ 512 -#define K 20 -#define MAX_DRAFT_TOKEN_NUM 6 - -struct batch_msgdata { - int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; - int ranks[MAX_DRAFT_TOKEN_NUM]; -}; - -struct msgdata { - long mtype; - int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums - batch_msgdata mtext[MAX_BSZ]; -}; - void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, const paddle::Tensor& logprob_token_ids, const paddle::Tensor& logprob_scores, @@ -53,7 +38,9 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, int message_flag, // Target: 3, Draft: 4 int64_t rank_id, bool save_each_rank) { - if (!save_each_rank && rank_id > 0) { + // NOTE(yaohuicong): Skip non-zero TP ranks — they share identical sampling + // outputs, so only rank 0 needs to send results to the message queue. + if (rank_id > 0) { return; } @@ -134,9 +121,11 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, msg_sed.mtype = 1; msg_sed.meta[0] = not_need_stop.data()[0] ? inference_msg_id_from_env : -inference_msg_id_from_env; - msg_sed.meta[1] = message_flag; - msg_sed.meta[2] = bsz; + // Pack message_flag (low 8 bits) and max_num_logprobs (high 24 bits) into + // meta[1]. Receiver unpacks both to avoid reading unused topk slots. int max_num_logprobs = logprob_token_ids.shape()[1]; + msg_sed.meta[1] = message_flag | (max_num_logprobs << 8); + msg_sed.meta[2] = bsz; for (int i = 0; i < bsz; i++) { int cur_token_num; if (seq_lens_decoder_data[i] < prompt_lens_data[i]) { @@ -152,19 +141,20 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, auto* cur_batch_msg_sed = &msg_sed.mtext[i]; int token_offset = cu_batch_token_offset_data[i]; for (int j = 0; j < cur_token_num; j++) { - auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; - auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; - for (int k = 0; k < K + 1; k++) { + // Use SPEC_LOGPROB_K+1 as stride (fixed struct layout), but only write + // max_num_logprobs columns to avoid filling unused topk slots. + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; + for (int k = 0; k < max_num_logprobs; k++) { if (k == 0) { cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j]; - cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k]; - } else if (k < max_num_logprobs) { - cur_tokens[k] = - (int)logprob_token_ids_data[(token_offset + j) * (K + 1) + k]; - cur_scores[k] = logprob_scores_data[(token_offset + j) * (K + 1) + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } else { - cur_tokens[k] = -1; - cur_scores[k] = 0.0; + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * max_num_logprobs + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * max_num_logprobs + k]; } } cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j]; @@ -173,22 +163,23 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, #ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG std::cout << "msg data: " << std::endl; std::cout << "stop_flag: " << msg_sed.meta[0] - << ", message_flag: " << msg_sed.meta[1] + << ", message_flag: " << (msg_sed.meta[1] & 0xFF) + << ", max_num_logprobs: " << (msg_sed.meta[1] >> 8) << ", bsz: " << msg_sed.meta[2] << std::endl; for (int i = 0; i < bsz; i++) { int cur_token_num = msg_sed.meta[3 + i]; auto* cur_batch_msg_sed = &msg_sed.mtext[i]; std::cout << "batch " << i << " token_num: " << cur_token_num << std::endl; for (int j = 0; j < cur_token_num; j++) { - auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; - auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)]; std::cout << "tokens: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << cur_tokens[k] << " "; } std::cout << std::endl; std::cout << "scores: "; - for (int k = 0; k < K + 1; k++) { + for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) { std::cout << cur_scores[k] << " "; } std::cout << std::endl; diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index ee364884e96..c6379387efe 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -51,60 +51,65 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, const int64_t step_idx_now = step_idx[bid]; const int64_t min_token_limit = min_tokens[bid]; - const bool can_stop = (step_idx_now >= min_token_limit); + const bool can_stop = (step_idx_now + accept_num >= min_token_limit); if (!can_stop) return; if (!stop_flags[bid]) { - int accept_idx = 0; + /* + accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based) + accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾 + (pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。 + 为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1] + (当前轮最后一个 token),该 token 延迟到下一轮匹配。 + 循环范围:accept_num > 0 时为 [-1, accept_num-2]; + accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。 + */ + int accept_idx = -1; bool is_end = false; - // 遍历起始位置 - for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { + + // 统一检测:accept_idx = -1 对应上一轮延迟的最后 token 在 pre_ids 末尾 + // 完整匹配 stop_seqs 的情况;accept_idx >= 0 对应当前轮 accept_tokens + // 中的匹配。两者共享同一套从后向前匹配逻辑。 + int loop_end = (accept_num > 0) ? accept_num - 2 : -1; + for (; accept_idx <= loop_end && !is_end; accept_idx++) { if (step_idx_now + accept_idx + 1 < stop_seq_len) { #ifdef DEBUG_SPEC_STOP_SEQS printf("num %d < stop_seq_len %d\n", - step_idx_now - accept_num + accept_idx + 1, + step_idx_now + accept_idx + 1, stop_seq_len); #endif continue; } - // 遍历一个 stop_seqs + // 从后向前匹配 stop_seq 的每个 token for (int i = stop_seq_len - 1; i >= 0; --i) { int64_t cur_token_idx = -1; - // 通过当前值判断 token 是在 pre_ids 还是 accept_token 里 - if (stop_seq_len - 1 - i < accept_idx) { + int offset = stop_seq_len - 1 - i; + int accept_tokens_idx = accept_idx - offset; + + if (accept_tokens_idx >= 0) { #ifdef DEBUG_SPEC_STOP_SEQS printf( "AcceptTokens bid:%d. tid:%d, accept_idx:%d, " - "accept_token_idx: " - "%d\n", + "accept_token_idx: %d\n", bid, tid, accept_idx, - accept_idx - (stop_seq_len - 1 - i) - 1); + accept_tokens_idx); #endif - cur_token_idx = - accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]; + cur_token_idx = accept_tokens_now[accept_tokens_idx]; } else { + int pre_ids_idx = step_idx_now + accept_tokens_idx; #ifdef DEBUG_SPEC_STOP_SEQS printf( "PreIds bid:%d. tid:%d, step_idx_now:%ld. " - "accept_idx:%d. " - "pre_id_idx: %ld\n", + "accept_idx:%d. pre_id_idx: %d\n", bid, tid, step_idx_now, accept_idx, - step_idx_now - accept_num + accept_idx - - (stop_seq_len - 1 - i)); + pre_ids_idx); #endif - int pre_ids_idx = - step_idx_now + accept_idx - (stop_seq_len - 1 - i); - // EC3 - // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, - // 导致异常结束 - if (pre_ids_idx <= 0) { - break; - } + if (pre_ids_idx < 0) break; cur_token_idx = pre_ids_now[pre_ids_idx]; } #ifdef DEBUG_SPEC_STOP_SEQS @@ -126,12 +131,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, } if (is_end) { #ifdef DEBUG_SPEC_STOP_SEQS - printf("bid:%d end with accept_idx %d", bid, accept_idx); + printf("bid:%d end with accept_idx %d\n", bid, accept_idx); #endif - - accept_nums[bid] = accept_idx; - accept_tokens_now[accept_idx - 1] = end_ids[0]; - // stop_flags[bid] = true; + // accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置 + accept_nums[bid] = accept_idx + 1; + accept_tokens_now[accept_idx] = end_ids[0]; } } } diff --git a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu index 94f71d6fd0e..f7ab5daece6 100644 --- a/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu +++ b/custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu @@ -121,7 +121,7 @@ __global__ void unified_update_model_status_kernel(int *seq_lens_encoder, int64_t *token_ids_all_now = &token_ids_all[batch_id * max_model_len + prompt_len]; int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; - int64_t base = cur_step_idx - output_len + 1; + int64_t base = cur_step_idx - output_len; for (int i = 0; i < output_len; i++) { token_ids_all_now[base + i] = output_ids[i]; } diff --git a/custom_ops/gpu_ops/stop_generation_multi_ends.cu b/custom_ops/gpu_ops/stop_generation_multi_ends.cu index d2a6dcbbf60..06cf99831d7 100644 --- a/custom_ops/gpu_ops/stop_generation_multi_ends.cu +++ b/custom_ops/gpu_ops/stop_generation_multi_ends.cu @@ -79,8 +79,9 @@ __global__ void set_value_by_flags(bool *stop_flags, // dealing stop_seqs const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; if (stop_seq_len <= 0) return; - const int64_t *stop_seq_now = - stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; + const int64_t *stop_seq_now = stop_seqs + + bid * stop_seqs_bs * stop_seqs_max_len + + tid * stop_seqs_max_len; const int64_t *pre_ids_now = token_ids_all + bid * max_model_len + prompt_lens[bid]; const int64_t step_idx_now = step_idx[bid]; diff --git a/custom_ops/gpu_ops/swap_cache_layout.cu b/custom_ops/gpu_ops/swap_cache_layout.cu index 62adccb2d04..08f64197f9b 100644 --- a/custom_ops/gpu_ops/swap_cache_layout.cu +++ b/custom_ops/gpu_ops/swap_cache_layout.cu @@ -15,74 +15,264 @@ #include "helper.h" #include "paddle/extension.h" -// #define SWAP_DEBUG +// D2H: Each thread block handles ALL layers for one swap block. +// This produces perfectly contiguous host writes (1 block × all layers), +// maximizing write-combining efficiency. +template +__global__ void swap_d2h_kernel(T** __restrict__ layer_ptrs, + T* __restrict__ cpu_buffer, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int block_idx = blockIdx.x; + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + int64_t num_vec_per_layer = (block_stride * sizeof(T)) / sizeof(float4); + + T* dst_base = cpu_buffer + (int64_t)block_idx * layer_num * block_stride; + + for (int layer_idx = 0; layer_idx < layer_num; layer_idx++) { + const T* src = layer_ptrs[layer_idx] + gpu_block * block_stride; + float4* dst4 = + reinterpret_cast(dst_base + layer_idx * block_stride); + const float4* src4 = reinterpret_cast(src); + + for (int64_t i = threadIdx.x; i < num_vec_per_layer; i += blockDim.x) { + dst4[i] = src4[i]; + } + } +} + +// H2D: scatter from contiguous staging buffer to scattered GPU layer tensors +template +__global__ void scatter_blocks_kernel(T** __restrict__ layer_ptrs, + const T* __restrict__ staging, + const int64_t* __restrict__ gpu_block_ids, + int n_blocks, + int layer_num, + int64_t block_stride) { + int pair_idx = blockIdx.x; + int block_idx = pair_idx / layer_num; + int layer_idx = pair_idx % layer_num; + + if (block_idx >= n_blocks) return; + + int64_t gpu_block = gpu_block_ids[block_idx]; + const T* src = staging + (int64_t)block_idx * layer_num * block_stride + + layer_idx * block_stride; + T* dst = layer_ptrs[layer_idx] + gpu_block * block_stride; + + int64_t num_vec = (block_stride * sizeof(T)) / sizeof(float4); + const float4* src4 = reinterpret_cast(src); + float4* dst4 = reinterpret_cast(dst); + + for (int64_t i = threadIdx.x; i < num_vec; i += blockDim.x) { + dst4[i] = src4[i]; + } +} + +static void* g_staging_buffer = nullptr; +static size_t g_staging_buffer_size = 0; +static void* g_device_block_ids = nullptr; +static size_t g_device_block_ids_size = 0; +static void* g_device_layer_ptrs = nullptr; +static size_t g_device_layer_ptrs_size = 0; + +static void ensure_staging_buffer(size_t required_size) { + if (g_staging_buffer_size < required_size) { + if (g_staging_buffer) cudaFree(g_staging_buffer); + cudaError_t err = cudaMalloc(&g_staging_buffer, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc staging buffer failed: %s", + cudaGetErrorString(err))); + g_staging_buffer_size = required_size; + } +} + +static void ensure_device_block_ids(size_t required_size) { + if (g_device_block_ids_size < required_size) { + if (g_device_block_ids) cudaFree(g_device_block_ids); + cudaError_t err = cudaMalloc(&g_device_block_ids, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device block_ids failed: %s", + cudaGetErrorString(err))); + g_device_block_ids_size = required_size; + } +} + +static void ensure_device_layer_ptrs(size_t required_size) { + if (g_device_layer_ptrs_size < required_size) { + if (g_device_layer_ptrs) cudaFree(g_device_layer_ptrs); + cudaError_t err = cudaMalloc(&g_device_layer_ptrs, required_size); + PADDLE_ENFORCE_EQ( + err, + cudaSuccess, + phi::errors::External("cudaMalloc device layer_ptrs failed: %s", + cudaGetErrorString(err))); + g_device_layer_ptrs_size = required_size; + } +} + +static bool is_cpu_block_ids_sequential( + const std::vector& cpu_block_ids) { + if (cpu_block_ids.empty()) return true; + int64_t start = cpu_block_ids[0]; + for (size_t i = 1; i < cpu_block_ids.size(); i++) { + if (cpu_block_ids[i] != start + static_cast(i)) return false; + } + return true; +} template -void SwapCacheImpLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_pointer, // cpu - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int mode) { - /* - mode is 0: gpu to cpu; 1: cpu to gpu - - cache layout: layer_num * [block_num, head_num, block_size, head_dim] - scale layout: layer_num * [block_num, head_num, block_size] - cache buffer layout: [block_num, layer_num, head_num, block_size, head_dim] - scale buffer layout: [block_num, layer_num, head_num, block_size] - */ +void SwapCacheImpLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_pointer, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int mode) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; const int64_t layer_number = cache_gpu_tensors.size(); int64_t cache_block_stride = 1; - for (int i = 1; i < cache_shape.size(); i++) { + for (size_t i = 1; i < cache_shape.size(); i++) { cache_block_stride *= cache_shape[i]; } + const int n_blocks = gpu_block_ids.size(); + if (n_blocks == 0) return; + auto stream = cache_gpu_tensors[0].stream(); - const cudaMemcpyKind copy_kind = - (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; - - for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) { - const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; - data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); - auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); - - for (int block_idx = 0; block_idx < gpu_block_ids.size(); block_idx++) { - auto cur_gpu_block_id = gpu_block_ids[block_idx]; - auto cur_cpu_block_id = cpu_block_ids[block_idx]; - auto* cache_gpu_ptr_now = - cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; - auto* cache_cpu_ptr_now = - cache_cpu_ptr + cur_cpu_block_id * cache_block_stride * layer_number + - layer_idx * cache_block_stride; - - cudaError_t status = cudaMemcpyAsync( - (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now - : cache_gpu_ptr_now, - (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now - : cache_cpu_ptr_now, - cache_block_stride * sizeof(DataType_), - copy_kind, - stream); + const size_t block_bytes = cache_block_stride * sizeof(DataType_); + const size_t total_bytes = (size_t)n_blocks * layer_number * block_bytes; + + bool use_optimized = is_cpu_block_ids_sequential(cpu_block_ids); + + // float4 vectorized kernels require block_bytes to be 16-byte aligned + // and cache_cpu_base to be 16-byte aligned for correct float4 access. + if (use_optimized && (block_bytes % sizeof(float4) != 0)) { + use_optimized = false; + } + if (use_optimized) { + int64_t cpu_start_block = cpu_block_ids[0]; + uintptr_t cpu_base_addr = + static_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride * sizeof(DataType_); + if (cpu_base_addr % sizeof(float4) != 0) { + use_optimized = false; + } + } + if (use_optimized) { + ensure_device_block_ids(n_blocks * sizeof(int64_t)); + ensure_device_layer_ptrs(layer_number * sizeof(DataType_*)); + + cudaError_t status = cudaMemcpyAsync(g_device_block_ids, + gpu_block_ids.data(), + n_blocks * sizeof(int64_t), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync block_ids H2D failed: %s", + cudaGetErrorString(status))); + + std::vector h_layer_ptrs(layer_number); + for (int64_t i = 0; i < layer_number; i++) { + h_layer_ptrs[i] = reinterpret_cast( + const_cast(cache_gpu_tensors[i].data())); + } + status = cudaMemcpyAsync(g_device_layer_ptrs, + h_layer_ptrs.data(), + layer_number * sizeof(DataType_*), + cudaMemcpyHostToDevice, + stream); + PADDLE_ENFORCE_EQ( + status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync layer_ptrs H2D failed: %s", + cudaGetErrorString(status))); + + int64_t cpu_start_block = cpu_block_ids[0]; + auto* cache_cpu_base = reinterpret_cast(cache_cpu_pointer) + + cpu_start_block * layer_number * cache_block_stride; + + int grid_size = n_blocks * layer_number; + + if (mode == 0) { + // GPU→CPU: direct kernel write to pinned host memory + // Multi-layer kernel: each block handles all layers for one swap block + swap_d2h_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + cache_cpu_base, + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } else { + // CPU→GPU: DMA memcpy to staging then scatter kernel + ensure_staging_buffer(total_bytes); + + status = cudaMemcpyAsync(g_staging_buffer, + cache_cpu_base, + total_bytes, + cudaMemcpyHostToDevice, + stream); PADDLE_ENFORCE_EQ(status, cudaSuccess, - phi::errors::External("cudaMemcpyAsync failed: %s", + phi::errors::External("cudaMemcpyAsync H2D failed: %s", cudaGetErrorString(status))); -#ifdef SWAP_DEBUG - cudaStreamSynchronize(stream); - std::cout << "mode:" << mode << ", layer_idx:" << layer_idx - << ", block_idx:" << block_idx << ", cache_cpu_ptr_now data:" - << static_cast(*cache_cpu_ptr_now) << std::endl; -#endif + scatter_blocks_kernel<<>>( + reinterpret_cast(g_device_layer_ptrs), + reinterpret_cast(g_staging_buffer), + reinterpret_cast(g_device_block_ids), + n_blocks, + layer_number, + cache_block_stride); + } + } else { + const cudaMemcpyKind copy_kind = + (mode == 0) ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; + for (int64_t layer_idx = 0; layer_idx < layer_number; layer_idx++) { + const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx]; + data_t* cache_gpu_ptr = const_cast(cache_gpu.data()); + auto* cache_cpu_ptr = reinterpret_cast(cache_cpu_pointer); + + for (int block_idx = 0; block_idx < n_blocks; block_idx++) { + auto cur_gpu_block_id = gpu_block_ids[block_idx]; + auto cur_cpu_block_id = cpu_block_ids[block_idx]; + auto* cache_gpu_ptr_now = + cache_gpu_ptr + cur_gpu_block_id * cache_block_stride; + auto* cache_cpu_ptr_now = + cache_cpu_ptr + + cur_cpu_block_id * cache_block_stride * layer_number + + layer_idx * cache_block_stride; + + cudaError_t status = cudaMemcpyAsync( + (copy_kind == cudaMemcpyDeviceToHost) ? cache_cpu_ptr_now + : cache_gpu_ptr_now, + (copy_kind == cudaMemcpyDeviceToHost) ? cache_gpu_ptr_now + : cache_cpu_ptr_now, + block_bytes, + copy_kind, + stream); + PADDLE_ENFORCE_EQ(status, + cudaSuccess, + phi::errors::External("cudaMemcpyAsync failed: %s", + cudaGetErrorString(status))); + } } } + cudaError_t sync_status = cudaStreamSynchronize(stream); PADDLE_ENFORCE_EQ(sync_status, cudaSuccess, @@ -90,15 +280,14 @@ void SwapCacheImpLayout( cudaGetErrorString(sync_status))); } -void SwapCacheLayout( - const std::vector& cache_gpu_tensors, // gpu - const int64_t& cache_cpu_ptrs, // cpu memory pointer - const std::vector& cache_shape, - const std::vector& gpu_block_ids, - const std::vector& cpu_block_ids, - int rank, - int mode) { - cudaSetDevice(rank); // used for distributed launch +void SwapCacheLayout(const std::vector& cache_gpu_tensors, + const int64_t& cache_cpu_ptrs, + const std::vector& cache_shape, + const std::vector& gpu_block_ids, + const std::vector& cpu_block_ids, + int rank, + int mode) { + cudaSetDevice(rank); assert(cache_gpu_tensors.size() > 0); switch (cache_gpu_tensors[0].dtype()) { case paddle::DataType::BFLOAT16: diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 180116bf2c7..1ba47905f1f 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -174,6 +174,12 @@ def get_gencode_flags(archs): "-gencode", f"arch=compute_{arch_code},code=sm_{arch_code}", ] + elif cc_val == 103: + arch_code = "103a" + flags += [ + "-gencode", + f"arch=compute_{arch_code},code=sm_{arch_code}", + ] else: flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"] return flags @@ -237,6 +243,7 @@ def find_end_files(directory, end_str): "gpu_ops/set_data_ipc.cu", "gpu_ops/unset_data_ipc.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/step_system_cache.cu", "gpu_ops/get_output_ep.cc", "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu", @@ -250,7 +257,7 @@ def find_end_files(directory, end_str): "gpu_ops/speculate_decoding/speculate_step.cu", "gpu_ops/speculate_decoding/speculate_step_system_cache.cu", "gpu_ops/speculate_decoding/speculate_update_v3.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/step_reschedule.cu", ] @@ -326,10 +333,13 @@ def find_end_files(directory, end_str): "gpu_ops/sample_kernels/rejection_top_p_sampling.cu", "gpu_ops/sample_kernels/top_k_renorm_probs.cu", "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", + "gpu_ops/get_position_ids_and_slot_mapping.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/custom_all_reduce/all_reduce.cu", "gpu_ops/merge_prefill_decode_output.cu", "gpu_ops/limit_thinking_content_length.cu", @@ -470,58 +480,60 @@ def find_end_files(directory, end_str): # This script seems general enough for different SM versions, specific templates are chosen by CUTLASS. os.system("python utils/auto_gen_visitor_fp8_gemm_fused_kernels.py") - if cc >= 90: # Hopper and newer - # SM90 (Hopper) specific auto-generation and flags - if cc == 90: # Only for SM90 - nvcc_compile_args += [ - # The gencode for 90a is added in get_gencode_flags now - # "-gencode", - # "arch=compute_90a,code=compute_90a", - "-O3", - "-DNDEBUG", # NDEBUG is common, consider moving if not specific to 90a - ] - print("SM90: Running SM90-specific FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") - os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") - - nvcc_compile_args += [ - "-DENABLE_SCALED_MM_SM90=1", - ] - sources += [ - "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", - "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", - "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", - ] - elif cc == 100 and nvcc_version >= 12.9: # Blackwell SM100 specifics - print("SM100 (Blackwell): Applying SM100 configurations.") - nvcc_compile_args += [ - # The gencode for 100a is added in get_gencode_flags - # "-gencode", - # "arch=compute_100a,code=compute_100a", - "-O3", # Common optimization flag - "-DNDEBUG", # Common debug flag - # Potentially add -DENABLE_SM100_FEATURES if specific macros are identified - ] - # Placeholder for SM100-specific kernel auto-generation scripts - # These might be needed if Blackwell has new FP8 hardware features - # not covered by existing generic CUTLASS templates or SM90 scripts. - # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") - # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example - # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example - - # Add SM100 specific sources if any, e.g., for new hardware intrinsics - # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example - pass # No SM100 specific sources identified yet beyond what CUTLASS handles - else: # For cc >= 89 but not 90 or 100 (e.g. SM89) - print(f"SM{cc}: Running generic FP8 kernel auto-generation.") - os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") - os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") - - else: # For cc == 89 (Ada) - print("SM89: Running generic FP8 kernel auto-generation.") + # Use non-exclusive checks against sm_versions so that building for + # multiple architectures (e.g. [80,90,100]) compiles kernels for ALL + # of them instead of only the highest one. + has_sm90 = 90 in sm_versions + has_sm100 = 100 in sm_versions and nvcc_version >= 12.9 + has_sm103 = 103 in sm_versions and nvcc_version >= 13.0 + has_generic_fp8 = not has_sm90 and not has_sm100 and not has_sm103 # SM89 or other + + if has_sm90 or has_sm100 or has_sm103: + nvcc_compile_args += [ + "-O3", + "-DNDEBUG", + ] + + if has_sm90: + print("SM90: Running SM90-specific FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm90.py") + os.system("python utils/auto_gen_fp8_fp8_block_gemm_fused_kernels_sm90.py") + + nvcc_compile_args += [ + "-DENABLE_SCALED_MM_SM90=1", + ] + sources += [ + "gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu", + "gpu_ops/cutlass_kernels/w8a8/scaled_mm_c3x_sm90.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_fp8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_sm90_int8.cu", + "gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu", + ] + + if has_sm100 or has_sm103: + print("SM100 / 103 (Blackwell): Applying SM100 / SM103 configurations.") + # Placeholder for SM100-specific kernel auto-generation scripts + # These might be needed if Blackwell has new FP8 hardware features + # not covered by existing generic CUTLASS templates or SM90 scripts. + # print("SM100: Running SM100-specific FP8 kernel auto-generation (if any).") + # os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels_sm100.py") # Example + # os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels_sm100.py") # Example + + # Add SM100 specific sources if any, e.g., for new hardware intrinsics + # sources += ["gpu_ops/cutlass_kernels/w8a8/c4x_sm100.cu"] # Example + pass # No SM100 specific sources identified yet beyond what CUTLASS handles + + if has_generic_fp8: + # For SM89 (Ada) or other architectures without dedicated paths + print(f"SM{cc}: Running generic FP8 kernel auto-generation.") + os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") + os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") + + if not has_sm90 and cc >= 90: + # When cc >= 90 but SM90 is not in the target list (e.g. only [80,100]), + # still run generic FP8 auto-generation for non-SM90 paths. + print(f"SM{cc}: Running generic FP8 kernel auto-generation (no SM90 target).") os.system("python utils/auto_gen_fp8_fp8_gemm_fused_kernels.py") os.system("python utils/auto_gen_fp8_fp8_dual_gemm_fused_kernels.py") @@ -542,6 +554,13 @@ def find_end_files(directory, end_str): sources += find_end_files(fp8_auto_gen_directory, ".cu") if cc >= 90 and nvcc_version >= 12.0: + # decode unified attention + os.system( + "python utils/auto_gen_template_attention.py --config gpu_ops/decode_unified_attention/template_config.json --output gpu_ops/decode_unified_attention/template_instantiation/autogen" + ) + sources += ["gpu_ops/decode_unified_attention.cu"] + sources += ["gpu_ops/decoder_write_cache_with_rope.cu"] + sources += find_end_files("gpu_ops/decode_unified_attention", ".cu") # Hopper optimized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] @@ -683,15 +702,18 @@ def find_end_files(directory, end_str): "gpu_ops/recover_decode_task.cu", "gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc_redundant.cu", + "gpu_ops/grouped_topk_kernels.cu", + "gpu_ops/fused_cast_sigmoid_bias.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/text_image_index_out.cu", - "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/get_position_ids.cu", "gpu_ops/limit_thinking_content_length.cu", "gpu_ops/update_attn_mask_offsets.cu", "gpu_ops/append_attn/mla_cache_kernel.cu", "gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", + "gpu_ops/moe/moe_align_kernel.cu", "gpu_ops/moe/moe_topk_select.cu", "gpu_ops/get_img_boundaries.cc", "gpu_ops/remote_cache_kv_ipc.cc", diff --git a/custom_ops/utils/auto_gen_template_attention.py b/custom_ops/utils/auto_gen_template_attention.py new file mode 100644 index 00000000000..5658f6645e7 --- /dev/null +++ b/custom_ops/utils/auto_gen_template_attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" + +import argparse +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template + + +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + for param_name in config.template_params: + if param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + def _build_template_args(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in config.template_params: + if param_name in params: + template_args_parts.append(str(params[param_name])) + + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _build_params_template_args(self, params: Dict[str, Any]) -> str: + """Build template arguments for AttentionParams.""" + params_template_args = [] + if "T" in params: + params_template_args.append(str(params["T"])) + else: + raise ValueError("Template parameter 'T' not found in dispatch_params") + + if "CacheT" in params: + params_template_args.append(str(params["CacheT"])) + else: + # C16 kernels use AttentionParams - T is repeated for both args + params_template_args.append(str(params["T"])) + + return f"<{', '.join(params_template_args)}>" + + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, params_template_args: str + ) -> str: + """Generate function signature.""" + if config.function_signature: + signature = config.function_signature.format( + function_name=config.function_name, + template_args=template_args, + params_template_args=params_template_args, + ) + + return signature + else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. + +#pragma once + +#include "../../{config.impl_file}" +""" + + def _generate_template_instantiation(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, params) + params_template_args = self._build_params_template_args(params) + return self._generate_function_signature(config, template_args, params_template_args) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + def generate_combinations_for_type(self, config: TemplateConfig) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] + + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + def generate_file_content( + self, + config: TemplateConfig, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + + for params in combinations: + content += self._generate_template_instantiation(config, params) + + return content + + def generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") + + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + combinations = self.generate_combinations_for_type(config) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + self._clean_output_directory(output_dir) + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Universal template instantiation generator") + parser.add_argument( + "--config", + "-c", + type=str, + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/custom_ops/xpu_ops/download_dependencies.sh b/custom_ops/xpu_ops/download_dependencies.sh index ad6d4d2dea6..b927448ccc2 100644 --- a/custom_ops/xpu_ops/download_dependencies.sh +++ b/custom_ops/xpu_ops/download_dependencies.sh @@ -15,8 +15,8 @@ if [ "$1" == "stable" ]; then version_xvllm="20251017" version_xtdk="3.4.0.1" else - version_xvllm="latest" - version_xtdk="latest" + version_xvllm="20260407" + version_xtdk="3.6.2.1" fi ( diff --git a/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc b/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc index 04d8efe71e7..cb50725fdbc 100644 --- a/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc +++ b/custom_ops/xpu_ops/src/ops/get_output_msg_with_topk.cc @@ -82,13 +82,17 @@ void GetOutputTopK(const paddle::Tensor& x, return; } - int bsz = msg_rcv.mtext[1]; + // Unpack bsz (low 16 bits) and actual_topk (high 16 bits) from mtext[1]. + // This matches the packing in save_output_msg_with_topk.cc: + // mtext[1] = bsz | (max_num_logprobs << 16) + int bsz = msg_rcv.mtext[1] & 0xFFFF; + int actual_topk = (msg_rcv.mtext[1] >> 16) & 0xFFFF; out_data[0] = (int64_t)msg_rcv.mtext[0]; - out_data[1] = (int64_t)msg_rcv.mtext[1]; + out_data[1] = (int64_t)msg_rcv.mtext[1]; // keep packed value; Python unpacks for (int i = 0; i < bsz; i++) { - for (int j = 0; j < k + 1; j++) { - const int64_t offset = i * (K + 1) + j; + for (int j = 0; j < actual_topk; j++) { + const int64_t offset = i * actual_topk + j; out_data[offset + 2] = (int64_t)msg_rcv.mtext[offset + 2]; scores_data[offset] = msg_rcv.mtext_f[offset]; } diff --git a/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc b/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc index 455e0fa18fb..154affbbde6 100644 --- a/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc +++ b/custom_ops/xpu_ops/src/ops/save_output_msg_with_topk.cc @@ -109,20 +109,21 @@ void SaveOutMmsgTopK(const paddle::Tensor& x, : -inference_msg_id_from_env; int bsz = x.shape()[0]; int max_num_logprobs = logprob_token_ids.shape()[1]; - msg_sed.mtext[1] = bsz; + // Pack bsz (low 16 bits) and max_num_logprobs (high 16 bits) into mtext[1]. + // token_processor unpacks both fields to avoid reading unused topk slots. + msg_sed.mtext[1] = bsz | (max_num_logprobs << 16); for (int i = 0; i < bsz; i++) { - for (int j = 0; j < K + 1; j++) { - const int64_t offset = i * (K + 1) + j; + // Loop only over actual logprob columns (max_num_logprobs) instead of the + // fixed K+1=21, and use max_num_logprobs as the stride so data is packed + // densely in the message buffer. + for (int j = 0; j < max_num_logprobs; j++) { + const int64_t offset = i * max_num_logprobs + j; if (j == 0) { msg_sed.mtext[offset + 2] = (int)x_data[i]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; - } else if (j < max_num_logprobs) { - msg_sed.mtext[offset + 2] = - (int)logprob_token_ids_data[i * max_num_logprobs + j]; - msg_sed.mtext_f[offset] = logprob_scores_data[i * max_num_logprobs + j]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } else { - msg_sed.mtext[offset + 2] = -1; - msg_sed.mtext_f[offset] = 0.0; + msg_sed.mtext[offset + 2] = (int)logprob_token_ids_data[offset]; + msg_sed.mtext_f[offset] = logprob_scores_data[offset]; } if (preempted_idx_data[i] == 1) { msg_sed.mtext[offset + 2] = -9; diff --git a/dockerfiles/Dockerfile.gpu b/dockerfiles/Dockerfile.gpu index 5ce8b05b199..4a4240cd76a 100644 --- a/dockerfiles/Dockerfile.gpu +++ b/dockerfiles/Dockerfile.gpu @@ -1,6 +1,6 @@ FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:tag-base -ARG PADDLE_VERSION=3.3.0 -ARG FD_VERSION=2.4.0 +ARG PADDLE_VERSION=3.3.1 +ARG FD_VERSION=2.5.0 ENV DEBIAN_FRONTEND=noninteractive @@ -16,8 +16,8 @@ RUN python -m pip uninstall paddlepaddle-gpu fastdeploy-gpu -y RUN python -m pip install --no-cache-dir paddlepaddle-gpu==${PADDLE_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ # build and install FastDeploy -RUN python -m pip install --no-cache-dir fastdeploy-gpu==${FD_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +RUN python -m pip install --no-cache-dir fastdeploy-gpu==${FD_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ENV http_proxy="" ENV https_proxy="" -ENV no_proxy="" +ENV no_proxy="" \ No newline at end of file diff --git a/dockerfiles/Dockerfile.xpu b/dockerfiles/Dockerfile.xpu index 14998860a12..3b98165dd0c 100644 --- a/dockerfiles/Dockerfile.xpu +++ b/dockerfiles/Dockerfile.xpu @@ -15,7 +15,7 @@ RUN python -m pip uninstall paddlepaddle-gpu paddlepaddle-xpu fastdeploy-xpu -y RUN python -m pip uninstall -y Pillow && rm -rf /usr/local/lib/python3.10/dist-packages/Pillow* && rm -rf /usr/local/lib/python3.10/dist-packages/pillow* && python -m pip install Pillow==11.3.0 # install paddlepaddle-xpu -ARG PADDLE_VERSION=nightly +ARG PADDLE_VERSION=3.3.1 RUN if [ "$PADDLE_VERSION" = "nightly" ]; then \ python -m pip install --no-cache-dir --progress-bar off paddlepaddle-xpu -i https://www.paddlepaddle.org.cn/packages/nightly/xpu-p800/; \ @@ -26,7 +26,7 @@ RUN if [ "$PADDLE_VERSION" = "nightly" ]; then \ # install fastdeploy-xpu ARG INSTALL_REQUIREMENTS=true ARG INSTALL_FASTDEPLOY=true -ARG FASTDEPLOY_VERSION=2.4.0 +ARG FASTDEPLOY_VERSION=2.5.0 RUN if [ "$INSTALL_FASTDEPLOY" = "true" ]; then \ python -m pip install --no-cache-dir fastdeploy-xpu==${FASTDEPLOY_VERSION} -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-xpu-p800/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple; \ @@ -40,4 +40,4 @@ RUN mkdir -p /workspace/deps && cd /workspace/deps && \ wget https://klx-sdk-release-public.su.bcebos.com/xre/kl3-release/5.0.21.21/xre-Linux-x86_64-5.0.21.21.tar.gz && \ tar -zxf xre-Linux-x86_64-5.0.21.21.tar.gz && mv xre-Linux-x86_64-5.0.21.21 xre -ENV PATH=/workspace/deps/xre/bin:$PATH +ENV PATH=/workspace/deps/xre/bin:$PATH \ No newline at end of file diff --git a/docs/benchmark.md b/docs/benchmark.md index 1a2e6f88031..7abdd68aac1 100644 --- a/docs/benchmark.md +++ b/docs/benchmark.md @@ -40,3 +40,100 @@ python benchmark_serving.py \ --max-concurrency 1 \ --save-result ``` + +## In-Process Benchmark Metrics Logger + +FastDeploy provides a built-in performance monitoring module that runs inside the inference process. It collects per-request timing data and computes rolling statistics aligned with `benchmark_serving.py`, writing results to a JSONL file for real-time monitoring and post-hoc analysis. + +### Enable + +Add `--benchmark-metrics-config` with a JSON string to the service startup command: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-0.3B-Base-Paddle \ + --benchmark-metrics-config '{"enable": true}' +``` + +### Configuration Parameters + +| Parameter | Type | Default | Description | +| :-------- | :--- | :------ | :---------- | +| `enable` | bool | `false` | Whether to enable the benchmark metrics logger. Must be set to `true` to activate. | +| `window_size` | int | `0` | Number of recent requests to aggregate. `0` = cumulative (all requests since start). | +| `window_mode` | str | `"sliding"` | Window aggregation mode. `"sliding"` = sliding window (keeps last N records, oldest automatically dropped). `"tumbling"` = tumbling window (clears and restarts after every N records). | +| `percentiles` | str | `"50,90,95,99"` | Comma-separated percentile values to compute. | +| `metrics` | str | `"all"` | Comma-separated metric names to report, or `"all"` for all metrics. | + +### Available Metrics + +Metrics are aligned with `benchmark_serving.py --percentile-metrics`: + +| Metric Name | Description | Unit | +| :---------- | :---------- | :--- | +| `ttft` | Time to First Token (client arrival → first token) | ms | +| `s_ttft` | Server TTFT (inference start → first token) | ms | +| `tpot` | Time per Output Token (excluding first token) | ms | +| `s_itl` | Infer Inter-token Latency | ms | +| `e2el` | End-to-end Latency (client arrival → last token) | ms | +| `s_e2el` | Server E2EL (inference start → last token) | ms | +| `s_decode` | Decode speed (excluding first token) | tok/s | +| `input_len` | Prefix cache hit token count ("Cached Tokens") | tokens | +| `s_input_len` | Infer input length (total prompt tokens) | tokens | +| `output_len` | Output token length per request | tokens | + +In addition, the following throughput metrics are always computed (not user-selectable) when there are 2+ records: + +| Metric | Description | Unit | +| :----- | :---------- | :--- | +| `request_throughput` | Request throughput | req/s | +| `output_throughput` | Output token throughput | tok/s | +| `total_throughput` | Total token throughput (input + output) | tok/s | + +### Window Modes + +**Sliding Window** (`"sliding"`, default): + +The window keeps the most recent N records. When a new record arrives and the window is full, the oldest record is automatically dropped. Each output line reflects the statistics of the latest N requests. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "sliding"}' +``` + +**Tumbling Window** (`"tumbling"`): + +The window accumulates records up to N, then clears and starts fresh. Each output line still reflects the current window's accumulated statistics, but the window resets at every boundary. This is useful for RL training scenarios where each step has a fixed batch size and you want per-step independent analysis. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "tumbling"}' +``` + +**No Window** (`window_size: 0`): + +All completed requests are accumulated. Statistics reflect the entire lifetime of the service. + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 0}' +``` + +### Output + +Results are written to `{FD_LOG_DIR}/benchmark_metrics.jsonl` (default: `./log/benchmark_metrics.jsonl`). Each line is a JSON object representing the window statistics at the time of a request completion. + +Example output line: + +```json +{ + "timestamp": "2026-05-14T10:30:05.123", + "window_size": 64, + "window_mode": "sliding", + "completed": 64, + "total_input_tokens": 8192, + "total_output_tokens": 16384, + "request_throughput": 5.2, + "output_throughput": 1250.0, + "total_throughput": 2500.0, + "ttft_ms": {"mean": 45.0, "median": 42.1, "p50": 42.1, "p90": 68.5, "p95": 82.3, "p99": 120.5}, + "s_decode": {"mean": 67.3, "median": 67.5, "p50": 67.5, "p90": 70.1, "p95": 71.2, "p99": 73.0} +} +``` diff --git a/docs/features/global_cache_pooling.md b/docs/features/global_cache_pooling.md index 2218e788cf3..3c8e18301b6 100644 --- a/docs/features/global_cache_pooling.md +++ b/docs/features/global_cache_pooling.md @@ -90,7 +90,7 @@ Create a `mooncake_config.json` file: "metadata_server": "http://0.0.0.0:15002/metadata", "master_server_addr": "0.0.0.0:15001", "global_segment_size": 1000000000, - "local_buffer_size": 134217728, + "local_buffer_size": 1048576, "protocol": "rdma", "rdma_devices": "" } diff --git a/docs/get_started/installation/nvidia_gpu.md b/docs/get_started/installation/nvidia_gpu.md index c59467175de..5a1f1ae2156 100644 --- a/docs/get_started/installation/nvidia_gpu.md +++ b/docs/get_started/installation/nvidia_gpu.md @@ -12,10 +12,13 @@ The following installation methods are available when your environment meets the ## 1. Pre-built Docker Installation (Recommended) -**Notice**: The pre-built image only supports SM80/90 GPU(e.g. H800/A800),if you are deploying on SM86/89GPU(L40/4090/L20), please reinstall ```fastdeploy-gpu``` after you create the container. +**Notice**: The pre-built image supports SM 80/86/89/90 architecture GPUs (e.g. A800/H800/L20/L40/4090), and requires Python 3.10. ```shell -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.4.0 +# CUDA 12.6 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 +# CUDA 12.9 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.9:2.5.0 ``` ## 2. Pre-built Pip Installation @@ -23,35 +26,38 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12 First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) ```shell # Install stable release -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.6 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.9 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ # Install latest Nightly build +# CUDA 12.6 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ +# CUDA 12.9 +python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ ``` -Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead: +Then install fastdeploy. **Do not install from PyPI**. Use the following methods instead (supports SM80/86/89/90 GPU architectures). -For SM80/90 architecture GPUs(e.g A30/A100/H100/): -``` -# Install stable release -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Install latest Nightly build -python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple -``` - -For SM86/89 architecture GPUs(e.g A10/4090/L20/L40): -``` -# Install stable release -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Install latest Nightly build +**Note**: Stable FastDeploy release pairs with stable PaddlePaddle; Nightly Build FastDeploy pairs with Nightly Build PaddlePaddle. The `--extra-index-url` is only used for downloading fastdeploy-gpu's dependencies; fastdeploy-gpu itself must be installed from the Paddle source specified by `-i`. +```shell +# Install stable release FastDeploy +# CUDA 12.6 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# Install Nightly Build FastDeploy +# CUDA 12.6 python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ``` ## 3. Build from Source Using Docker -- Note: ```dockerfiles/Dockerfile.gpu``` by default supports SM 80/90 architectures. To support other architectures, modify ```bash build.sh 1 python false [80,90]``` in the Dockerfile. It's recommended to specify no more than 2 architectures. +> Note: `dockerfiles/Dockerfile.gpu` currently supports CUDA 12.6 only, targeting SM 80/86/89/90 architectures, and requires Python 3.10. To support other architectures, modify ```bash build.sh 1 python false [80,90]``` in the Dockerfile. It's recommended to specify no more than 2 architectures. ```shell git clone https://github.com/PaddlePaddle/FastDeploy @@ -64,10 +70,8 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu . First install paddlepaddle-gpu. For detailed instructions, refer to [PaddlePaddle Installation](https://www.paddlepaddle.org.cn/en/install/quick?docurl=/documentation/docs/en/develop/install/pip/linux-pip_en.html) ```shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` - -Then clone the source code and build: ```shell git clone https://github.com/PaddlePaddle/FastDeploy cd FastDeploy @@ -82,8 +86,7 @@ The built packages will be in the ```FastDeploy/dist``` directory. ## 5. Precompiled Operator Wheel Packages -FastDeploy provides precompiled GPU operator wheel packages for quick setup without building the entire source code. -This method currently supports **SM80/90 architecture (e.g., A100/H100)** and **CUDA 12.6** environments only. +FastDeploy provides precompiled GPU operator wheel packages for quick setup without building the entire source code. This method currently supports **SM80/90 architecture (e.g., A100/H100)** **CUDA 12.6** and **Python 3.10** environments only. > By default, `build.sh` compiles all custom operators from source.To use the precompiled package, enable it with the `FD_USE_PRECOMPILED` parameter. > If the precompiled package cannot be downloaded or does not match the current environment, the system will automatically fall back to `4. Build Wheel from Source`. @@ -92,7 +95,7 @@ First, install paddlepaddle-gpu. For detailed instructions, please refer to the [PaddlePaddle Installation Guide](https://www.paddlepaddle.org.cn/). ```shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` Then, clone the FastDeploy repository and build using the precompiled operator wheels: @@ -112,7 +115,7 @@ cd FastDeploy bash build.sh 1 python false [90] 1 # Use precompiled wheel from a specific commit -bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 +bash build.sh 1 python false [90] 1 d693d4be1448d414097882386fdc24c8bec2a63a ``` The downloaded wheel packages will be stored in the `FastDeploy/pre_wheel` directory. @@ -121,9 +124,9 @@ After the build completes, the operator binaries can be found in `FastDeploy/fas > **Notes:** > > - This mode prioritizes downloading precompiled GPU operator wheels to reduce build time. -> - Currently supports **GPU, SM80/90, CUDA 12.6** only. +> - Currently supports **GPU, SM80/90, CUDA 12.6, Python3.10** only. > - For custom architectures or modified operator logic, please use **source compilation (Section 4)**. -> - You can check whether the precompiled wheel for a specific commit has been successfully built on the [FastDeploy CI Build Status Page](https://github.com/PaddlePaddle/FastDeploy/actions/workflows/ci_image_update.yml). +> - You can check whether the precompiled wheel for a specific commit has been successfully built on the [FastDeploy CI Build Status Page](https://github.com/PaddlePaddle/FastDeploy/actions/workflows/ce_job.yml). ## Environment Verification diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md index 2b447476020..c9dba035339 100644 --- a/docs/online_serving/README.md +++ b/docs/online_serving/README.md @@ -577,4 +577,4 @@ DeltaFunctionCall: - `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset. - `/v1/resume` - Resume generation. - `/v1/is_paused` - Check if generation is paused. -- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts. +- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). diff --git a/docs/online_serving/metrics.md b/docs/online_serving/metrics.md index 8e2dd286888..4f30cf3a222 100644 --- a/docs/online_serving/metrics.md +++ b/docs/online_serving/metrics.md @@ -46,3 +46,81 @@ After FastDeploy is launched, it supports continuous monitoring of the FastDeplo - Access URL: `http://localhost:8000/metrics` - Metric Type: Prometheus format + +## Trace Events + +FastDeploy outputs structured trace events to `trace.log` at key stages of request processing, useful for diagnosing per-request latency bottlenecks. Each trace log entry contains fields such as `timestamp` (milliseconds), `request_id`, `event`, and `stage`. + +### Common Events (Mixed / All Instances) + +| Stage | Event | Description | +| :---: | --- | --- | +| PREPROCESSING | `PREPROCESSING_START` | API Server begins preprocessing the request | +| PREPROCESSING | `PREPROCESSING_END` | Engine receives the request, preprocessing complete | +| SCHEDULE | `REQUEST_SCHEDULE_START` | Request enters the scheduling flow | +| SCHEDULE | `REQUEST_QUEUE_START` | Request enters the scheduling queue | +| SCHEDULE | `REQUEST_QUEUE_END` | Request dequeued from the scheduling queue | +| SCHEDULE | `RESOURCE_ALLOCATE_START` | Resource allocation begins for the request | +| SCHEDULE | `PREPARE_PREFIX_CACHE_START` | Prefix cache block matching begins | +| SCHEDULE | `PREPARE_PREFIX_CACHE_END` | Prefix cache block matching complete | +| SCHEDULE | `RESOURCE_ALLOCATE_END` | Resource allocation complete | +| SCHEDULE | `REQUEST_SCHEDULE_END` | Scheduling flow complete | +| PREFILL | `INFERENCE_START` | Request sent to GPU for inference | +| PREFILL | `FIRST_TOKEN_GENERATED` | First token generated | +| DECODE | `DECODE_START` | Enters Decode phase | +| DECODE | `INFERENCE_END` | Inference complete (all tokens generated) | +| DECODE | `PREEMPTED` | Request preempted | +| DECODE | `RESCHEDULED_INFERENCE_START` | Preempted request rescheduled for execution | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_START` | Begins writing KV Cache to external storage | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_END` | KV Cache written to external storage | +| POSTPROCESSING | `POSTPROCESSING_START` | Post-processing begins | +| POSTPROCESSING | `POSTPROCESSING_END` | Post-processing complete, response sent | + +### PD Disaggregation — Prefill (P) Instance Events + +| Stage | Event | Description | +| :---: | --- | --- | +| SCHEDULE | `ASK_DECODE_RESOURCE_START` | P begins requesting resources from D (sends ZMQ request) | +| SCHEDULE | `ASK_DECODE_RESOURCE_END` | P receives resource allocation confirmation from D (with dest_block_ids) | +| PREFILL | `PREFILL_INFERENCE_END` | P instance Prefill inference complete | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_START` | P begins waiting for KV Cache transfer to complete | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_END` | KV Cache transfer confirmed, ready to send first token to D | + +### PD Disaggregation — Decode (D) Instance Events + +| Stage | Event | Description | +| :---: | --- | --- | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_START` | D begins processing resource allocation request from P | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_END` | D completes resource allocation and returns dest_block_ids to P | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_START` | D receives first token from P, begins processing Prefilled request | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_END` | D adds Prefilled request to running queue | +| DECODE | `DECODE_INFERENCE_END` | D instance Decode inference complete | + +### Request Lifecycle Sequence + +**Mixed mode** (single instance, full inference): +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END → INFERENCE_START +→ FIRST_TOKEN_GENERATED → DECODE_START → INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` + +**PD Disaggregation — Prefill (P) Instance**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ ASK_DECODE_RESOURCE_START → ASK_DECODE_RESOURCE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END +→ INFERENCE_START → PREFILL_INFERENCE_END +→ CHECK_CACHE_TRANSFER_START → CHECK_CACHE_TRANSFER_END → [send first token to D] +``` + +**PD Disaggregation — Decode (D) Instance**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ DECODE_PROCESS_PREALLOCATE_REQUEST_START → DECODE_PROCESS_PREALLOCATE_REQUEST_END +→ [wait for P to complete prefill and transfer KV Cache] +→ DECODE_PROCESS_PREFILLED_REQUEST_START → DECODE_PROCESS_PREFILLED_REQUEST_END +→ INFERENCE_START → DECODE_INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` diff --git a/docs/usage/environment_variables.md b/docs/usage/environment_variables.md index 692ad8cd023..e54ec8f8798 100644 --- a/docs/usage/environment_variables.md +++ b/docs/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/docs/zh/benchmark.md b/docs/zh/benchmark.md index e4a58d93b1e..e0c55c63ef3 100644 --- a/docs/zh/benchmark.md +++ b/docs/zh/benchmark.md @@ -40,3 +40,106 @@ python benchmark_serving.py \ --max-concurrency 1 \ --save-result ``` + +## 进程内性能监控(Benchmark Metrics Logger) + +FastDeploy 提供了内置的进程内性能监控模块,在推理进程内部运行,复用已有的请求时间戳数据,每个请求完成时计算滚动统计并写入 JSONL 文件,可用于实时监控和事后分析。 + +### 启用方式 + +在服务启动命令中添加 `--benchmark-metrics-config` 参数,传入 JSON 配置字符串: + +```bash +python -m fastdeploy.entrypoints.openai.api_server \ + --model baidu/ERNIE-4.5-0.3B-Base-Paddle \ + --benchmark-metrics-config '{"enable": true}' +``` + +### 配置参数 + +| 参数 | 类型 | 默认值 | 说明 | +| :--- | :--- | :----- | :--- | +| `enable` | bool | `false` | 是否启用性能监控。必须设置为 `true` 才会激活。 | +| `window_size` | int | `0` | 统计窗口大小。`0` = 累计模式(统计所有请求);`>0` = 统计最近 N 个请求。 | +| `window_mode` | str | `"sliding"` | 窗口聚合模式。`"sliding"` = 滑动窗口(保持最近 N 条,旧记录自动淘汰);`"tumbling"` = 翻滚窗口(满 N 条后清空重新累积)。 | +| `percentiles` | str | `"50,90,95,99"` | 要计算的分位值,逗号分隔。 | +| `metrics` | str | `"all"` | 要统计的指标子集,逗号分隔,或 `"all"` 表示全部指标。 | + +### 可用指标 + +指标与 `benchmark_serving.py --percentile-metrics` 对齐: + +| 指标名称 | 说明 | 单位 | +| :------- | :--- | :--- | +| `ttft` | 首 Token 时延(客户端到达 → 首 Token) | ms | +| `s_ttft` | 服务端首 Token 时延(推理开始 → 首 Token) | ms | +| `tpot` | 每 Token 输出时延(不含首 Token) | ms | +| `s_itl` | 推理 Token 间时延 | ms | +| `e2el` | 端到端时延(客户端到达 → 最后一个 Token) | ms | +| `s_e2el` | 服务端端到端时延(推理开始 → 最后一个 Token) | ms | +| `s_decode` | 解码速度(不含首 Token) | tok/s | +| `input_len` | 前缀缓存命中 Token 数("Cached Tokens") | tokens | +| `s_input_len` | 推理输入长度(总 prompt token 数) | tokens | +| `output_len` | 输出 Token 长度 | tokens | + +此外,以下吞吐量指标在有 2 个以上请求完成时自动计算(不受 `metrics` 参数控制): + +| 指标 | 说明 | 单位 | +| :--- | :--- | :--- | +| `request_throughput` | 请求吞吐量 | req/s | +| `output_throughput` | 输出 Token 吞吐量 | tok/s | +| `total_throughput` | 总 Token 吞吐量(输入 + 输出) | tok/s | + +### 窗口模式 + +**滑动窗口**(`"sliding"`,默认): + +窗口始终保持最近 N 条记录。当新记录到达且窗口已满时,最旧的记录自动淘汰。每行输出反映最近 N 个请求的统计值。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "sliding"}' +``` + +**翻滚窗口**(`"tumbling"`): + +窗口累积到 N 条后清空重新开始。每行输出反映当前窗口已累积请求的统计值,窗口在边界处重置。适用于 RL 训练场景,每个 step 有固定 batch size,需要逐 step 独立分析。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 64, "window_mode": "tumbling"}' +``` + +**无窗口**(`window_size: 0`): + +所有已完成请求持续累积,统计值反映服务启动以来的全量数据。 + +```bash +--benchmark-metrics-config '{"enable": true, "window_size": 0}' +``` + +### 输出说明 + +结果写入 `{FD_LOG_DIR}/benchmark_metrics.jsonl`(默认路径:`./log/benchmark_metrics.jsonl`)。每行为一个 JSON 对象,表示某个请求完成时刻窗口内的统计快照。 + +输出示例: + +```json +{ + "timestamp": "2026-05-14T10:30:05.123", + "window_size": 64, + "window_mode": "sliding", + "completed": 64, + "total_input_tokens": 8192, + "total_output_tokens": 16384, + "request_throughput": 5.2, + "output_throughput": 1250.0, + "total_throughput": 2500.0, + "ttft_ms": {"mean": 45.0, "median": 42.1, "p50": 42.1, "p90": 68.5, "p95": 82.3, "p99": 120.5}, + "s_decode": {"mean": 67.3, "median": 67.5, "p50": 67.5, "p90": 70.1, "p95": 71.2, "p99": 73.0} +} +``` + +读取最后一行即可获取当前最新的性能快照: + +```bash +tail -1 log/benchmark_metrics.jsonl | python -m json.tool +``` diff --git a/docs/zh/features/global_cache_pooling.md b/docs/zh/features/global_cache_pooling.md index 292e764ac80..b0cf985f3a3 100644 --- a/docs/zh/features/global_cache_pooling.md +++ b/docs/zh/features/global_cache_pooling.md @@ -90,7 +90,7 @@ pip install ./dist/fastdeploy*.whl "metadata_server": "http://0.0.0.0:15002/metadata", "master_server_addr": "0.0.0.0:15001", "global_segment_size": 1000000000, - "local_buffer_size": 134217728, + "local_buffer_size": 1048576, "protocol": "rdma", "rdma_devices": "" } diff --git a/docs/zh/get_started/installation/nvidia_gpu.md b/docs/zh/get_started/installation/nvidia_gpu.md index 8e989db5368..dd266b6c7eb 100644 --- a/docs/zh/get_started/installation/nvidia_gpu.md +++ b/docs/zh/get_started/installation/nvidia_gpu.md @@ -14,10 +14,13 @@ ## 1. 预编译Docker安装(推荐) -**注意**: 如下镜像仅支持SM 80/90架构GPU(A800/H800等),如果你是在L20/L40/4090等SM 86/89架构的GPU上部署,请在创建容器后,卸载```fastdeploy-gpu```再重新安装如下文档指定支持86/89架构的`fastdeploy-gpu`包。 +**注意**: 预编译镜像支持 80/86/89/90 架构的GPU硬件 (如 A800/H800/L20/L40/4090) 且仅支持 Python 3.10。 ``` shell -docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.4.0 +# CUDA 12.6 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.6:2.5.0 +# CUDA 12.9 +docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12.9:2.5.0 ``` ## 2. 预编译Pip安装 @@ -26,37 +29,38 @@ docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/fastdeploy-cuda-12 ``` shell # Install stable release -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.6 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +# CUDA 12.9 +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ # Install latest Nightly build +# CUDA 12.6 python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ +# CUDA 12.9 +python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ ``` -再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装 +再安装 fastdeploy,**注意不要通过pypi源安装**,需要通过如下方式安装(目前支持80/86/89/90四个架构GPU) -如你的 GPU 是 SM80/90 架构(A100/H100等),按如下方式安装 - -``` -# 安装稳定版本fastdeploy -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-80_90/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# 安装Nightly Build的最新版本fastdeploy -python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple -``` - -如你的 GPU 是 SM86/89 架构(4090/L20/L40等),按如下方式安装 - -``` -# 安装稳定版本fastdeploy -python -m pip install fastdeploy-gpu==2.4.0 -i https://www.paddlepaddle.org.cn/packages/stable/fastdeploy-gpu-86_89/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# 安装Nightly Build的最新版本fastdeploy +**注意**: 稳定版本的FastDeploy搭配稳定版本的PaddlePaddle; 而Nightly Build的FastDeploy则对应Nightly Build的PaddlePaddle。其中 `--extra-index-url` 仅用于安装 fastdeploy-gpu 所需的依赖包,fastdeploy-gpu 本身必须从 `-i` 指定的 Paddle 源安装。 +```shell +# 安装稳定版本FastDeploy +# CUDA 12.6 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu==2.5.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + +# 安装Nightly Build版本FastDeploy +# CUDA 12.6 python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu126/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# CUDA 12.9 +python -m pip install fastdeploy-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple ``` ## 3. 镜像自行构建 -> 注意 ```dockerfiles/Dockerfile.gpu``` 默认编译的架构支持SM 80/90,如若需要支持其它架构,需自行修改Dockerfile中的 ```bash build.sh 1 python false [80,90]```,建议不超过2个架构。 +> 注意 ```dockerfiles/Dockerfile.gpu``` 默认编译产物仅支持 SM 80/86/89/90 架构,基于 CUDA 12.6 环境构建,且仅支持 Python 3.10,如若需要支持其它架构,需自行修改Dockerfile中的 ```bash build.sh 1 python false [80,90]```,建议不超过2个架构。 ``` git clone https://github.com/PaddlePaddle/FastDeploy @@ -70,7 +74,7 @@ docker build -f dockerfiles/Dockerfile.gpu -t fastdeploy:gpu . 首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/) ``` shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` 接着克隆源代码,编译安装 @@ -90,7 +94,7 @@ bash build.sh 1 python false [80,90] ## 5. 算子预编译 Wheel 包 -FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码编译的情况下快速构建。该方式当前仅支持 **SM80/90 架构(A100/H100等)** 和 **CUDA 12.6** 环境。 +FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码编译的情况下快速构建。该方式当前仅支持 **SM80/90 架构(A100/H100等)** **CUDA 12.6** 和 **Python 3.10** 环境。 >默认情况下,`build.sh` 会从源码编译;若希望使用预编译包,可使用`FD_USE_PRECOMPILED` 参数; >若预编译包下载失败或与环境不匹配,系统会自动回退至 `4. wheel 包源码编译` 模式。 @@ -98,7 +102,7 @@ FastDeploy 提供了 GPU 算子预编译版 Wheel 包,可在无需完整源码 首先安装 paddlepaddle-gpu,详细安装方式参考 [PaddlePaddle安装](https://www.paddlepaddle.org.cn/) ``` shell -python -m pip install paddlepaddle-gpu==3.3.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ +python -m pip install paddlepaddle-gpu==3.3.1 -i https://www.paddlepaddle.org.cn/packages/stable/cu126/ ``` 接着克隆源代码,拉取 whl 包并安装 @@ -118,7 +122,7 @@ cd FastDeploy bash build.sh 1 python false [90] 1 # 从指定 commitID 获取对应预编译算子 -bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 +bash build.sh 1 python false [90] 1 d693d4be1448d414097882386fdc24c8bec2a63a ``` 下载的 whl 包在 `FastDeploy/pre_wheel`目录下。 @@ -127,7 +131,7 @@ bash build.sh 1 python false [90] 1 8a9e7b53af4a98583cab65e4b44e3265a93e56d2 > **说明:** > - 该模式会优先下载预编译的 GPU 算子 whl 包,减少编译时间; -> - 目前仅支持 **GPU, SM80/90 架构, CUDA 12.6**; +> - 目前仅支持 **GPU, SM80/90 架构, CUDA 12.6, Python3.10**; > - 若希望自定义架构或修改算子逻辑,请使用 **源码编译方式(第4节)**。 > - 您可以在 FastDeploy CI 构建状态页面查看对应 commit 的预编译 whl 是否已构建成功。 diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md index 21f16d06e32..0264c928bd5 100644 --- a/docs/zh/online_serving/README.md +++ b/docs/zh/online_serving/README.md @@ -563,4 +563,4 @@ DeltaFunctionCall: /v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。 /v1/resume - 恢复推理生成。 /v1/is_paused - 检查推理生成是否已暂停。 -/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。 +/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。 diff --git a/docs/zh/online_serving/metrics.md b/docs/zh/online_serving/metrics.md index 630f68e2ff8..20da957bcf2 100644 --- a/docs/zh/online_serving/metrics.md +++ b/docs/zh/online_serving/metrics.md @@ -46,3 +46,81 @@ - 访问地址:`http://localhost:8000/metrics` - 指标类型:Prometheus 格式 + +## Trace 事件 + +FastDeploy 在请求处理的关键阶段输出结构化 trace 事件到 `trace.log`,用于定位请求级别的延迟瓶颈。每条 trace 日志包含 `timestamp`(毫秒)、`request_id`、`event`、`stage` 等字段。 + +### 通用事件(Mixed / 所有实例) + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| PREPROCESSING | `PREPROCESSING_START` | API Server 开始预处理请求 | +| PREPROCESSING | `PREPROCESSING_END` | Engine 收到请求,预处理完成 | +| SCHEDULE | `REQUEST_SCHEDULE_START` | 请求进入调度流程 | +| SCHEDULE | `REQUEST_QUEUE_START` | 请求进入调度队列等待 | +| SCHEDULE | `REQUEST_QUEUE_END` | 请求从调度队列取出 | +| SCHEDULE | `RESOURCE_ALLOCATE_START` | 开始为请求分配资源 | +| SCHEDULE | `PREPARE_PREFIX_CACHE_START` | 开始匹配前缀缓存块 | +| SCHEDULE | `PREPARE_PREFIX_CACHE_END` | 前缀缓存块匹配完成 | +| SCHEDULE | `RESOURCE_ALLOCATE_END` | 资源分配完成 | +| SCHEDULE | `REQUEST_SCHEDULE_END` | 调度流程结束 | +| PREFILL | `INFERENCE_START` | 请求送入 GPU 执行推理 | +| PREFILL | `FIRST_TOKEN_GENERATED` | 首 token 生成 | +| DECODE | `DECODE_START` | 进入 Decode 阶段 | +| DECODE | `INFERENCE_END` | 推理完成(所有 token 生成完毕) | +| DECODE | `PREEMPTED` | 请求被抢占 | +| DECODE | `RESCHEDULED_INFERENCE_START` | 被抢占的请求重新调度执行 | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_START` | 开始将 KV Cache 写入外部存储 | +| POSTPROCESSING | `WRITE_CACHE_TO_STORAGE_END` | KV Cache 写入外部存储完成 | +| POSTPROCESSING | `POSTPROCESSING_START` | 开始后处理 | +| POSTPROCESSING | `POSTPROCESSING_END` | 后处理完成,响应发送完毕 | + +### PD 分离 — Prefill (P) 实例专属事件 + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| SCHEDULE | `ASK_DECODE_RESOURCE_START` | P 开始向 D 申请资源(发送 ZMQ 请求) | +| SCHEDULE | `ASK_DECODE_RESOURCE_END` | P 收到 D 的资源分配确认(含 dest_block_ids) | +| PREFILL | `PREFILL_INFERENCE_END` | P 实例 Prefill 推理完成 | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_START` | P 开始等待 KV Cache 传输完成 | +| POSTPROCESSING | `CHECK_CACHE_TRANSFER_END` | KV Cache 传输完成确认,准备发送 first token 到 D | + +### PD 分离 — Decode (D) 实例专属事件 + +| 阶段 | 事件 | 说明 | +| :---: | --- | --- | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_START` | D 开始处理 P 发来的资源分配请求 | +| DECODE | `DECODE_PROCESS_PREALLOCATE_REQUEST_END` | D 完成资源分配并返回 dest_block_ids 给 P | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_START` | D 收到 P 的 first token,开始处理 Prefilled 请求 | +| DECODE | `DECODE_PROCESS_PREFILLED_REQUEST_END` | D 将 Prefilled 请求加入 running queue | +| DECODE | `DECODE_INFERENCE_END` | D 实例 Decode 推理完成 | + +### 请求生命周期时序图 + +**Mixed 模式**(单实例完整推理): +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END → INFERENCE_START +→ FIRST_TOKEN_GENERATED → DECODE_START → INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` + +**PD 分离 — Prefill (P) 实例**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ ASK_DECODE_RESOURCE_START → ASK_DECODE_RESOURCE_END +→ RESOURCE_ALLOCATE_START → RESOURCE_ALLOCATE_END +→ INFERENCE_START → PREFILL_INFERENCE_END +→ CHECK_CACHE_TRANSFER_START → CHECK_CACHE_TRANSFER_END → [发送 first token 到 D] +``` + +**PD 分离 — Decode (D) 实例**: +``` +PREPROCESSING_START → PREPROCESSING_END → REQUEST_QUEUE_START → REQUEST_QUEUE_END +→ DECODE_PROCESS_PREALLOCATE_REQUEST_START → DECODE_PROCESS_PREALLOCATE_REQUEST_END +→ [等待 P 完成 prefill 并传输 KV Cache] +→ DECODE_PROCESS_PREFILLED_REQUEST_START → DECODE_PROCESS_PREFILLED_REQUEST_END +→ INFERENCE_START → DECODE_INFERENCE_END +→ POSTPROCESSING_START → POSTPROCESSING_END +``` diff --git a/docs/zh/usage/environment_variables.md b/docs/zh/usage/environment_variables.md index 0a4cfd389db..ab625bd4d2c 100644 --- a/docs/zh/usage/environment_variables.md +++ b/docs/zh/usage/environment_variables.md @@ -162,6 +162,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # 是否启用 decode 缓存请求以预分配资源 "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # EP 中批处理 token 的超时时间 + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), + # PD 中最大预取请求数量 "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), diff --git a/examples/cache_storage/run_03b_pd_storage.sh b/examples/cache_storage/run_03b_pd_storage.sh index 5577a0ebf27..c940fe9a8ef 100644 --- a/examples/cache_storage/run_03b_pd_storage.sh +++ b/examples/cache_storage/run_03b_pd_storage.sh @@ -18,7 +18,7 @@ metadata_port=15002 export MOONCAKE_MASTER_SERVER_ADDR="${master_ip}:${master_port}" export MOONCAKE_METADATA_SERVER="http://${master_ip}:${metadata_port}/metadata" -export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" +export MOONCAKE_GLOBAL_SEGMENT_SIZE="50000000000" # 50GB # export MOONCAKE_PROTOCOL="tcp" export MOONCAKE_PROTOCOL="rdma" # export MOONCAKE_RDMA_DEVICES="mlx5_0" diff --git a/examples/splitwise/start_v0_tp1.sh b/examples/splitwise/start_v0_tp1.sh deleted file mode 100644 index 40c20301138..00000000000 --- a/examples/splitwise/start_v0_tp1.sh +++ /dev/null @@ -1,113 +0,0 @@ -#!/bin/bash -set -e - -# Test splitwise deployment -# There are two methods for splitwise deployment: -# v0: using splitwise_scheduler or dp_scheduler (deprecated) -# v1: using local_scheduler + router - -# prepare environment -export MODEL_NAME="PaddlePaddle/ERNIE-4.5-0.3B-Paddle" -export FD_DEBUG=1 -export ENABLE_V1_KVCACHE_SCHEDULER=1 -export KVCACHE_GDRCOPY_FLUSH_ENABLE=1 - -SCRIPT_PATH=$(readlink -f "$0") -SCRIPT_DIR=$(dirname "$SCRIPT_PATH") -export $(bash ${SCRIPT_DIR}/../../scripts/get_rdma_nics.sh gpu) -echo "KVCACHE_RDMA_NICS:${KVCACHE_RDMA_NICS}" -if [ -z "${KVCACHE_RDMA_NICS}" ]; then - echo "KVCACHE_RDMA_NICS is empty, please check the output of get_rdma_nics.sh" - exit 1 -fi - -unset http_proxy && unset https_proxy -source ${SCRIPT_DIR}/utils.sh - -P_PORT=52400 -D_PORT=52500 -REDIS_PORT="${REDIS_PORT:-6379}" -LOG_DATE=$(date +%Y%m%d_%H%M%S) - -ports=( - $P_PORT $((P_PORT + 1)) $((P_PORT + 2)) $((P_PORT + 3)) $((P_PORT + 4)) $((P_PORT + 5)) - $D_PORT $((D_PORT + 1)) $((D_PORT + 2)) $((D_PORT + 3)) $((D_PORT + 4)) $((D_PORT + 5)) - $REDIS_PORT -) -check_ports "${ports[@]}" || { - echo "❌ Some ports are in use. Please release them." - exit 1 -} - -# start redis -if ! redis-cli -p ${REDIS_PORT} ping &>/dev/null; then - echo "Redis is not running. Starting redis-server..." - redis-server --daemonize yes --port ${REDIS_PORT} - sleep 1 -else - echo "Redis is already running." -fi -sleep 1 - -# start prefill -export CUDA_VISIBLE_DEVICES=0 -export FD_LOG_DIR="log/$LOG_DATE/prefill" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${P_PORT} \ - --metrics-port $((P_PORT + 1)) \ - --engine-worker-queue-port $((P_PORT + 2)) \ - --cache-queue-port $((P_PORT + 3)) \ - --max-model-len 32768 \ - --num-gpu-blocks-override 1000 \ - --splitwise-role "prefill" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((P_PORT + 4)) \ - --pd-comm-port $((P_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${P_PORT} - -# start decode -export CUDA_VISIBLE_DEVICES=1 -export FD_LOG_DIR="log/$LOG_DATE/decode" -rm -rf ${FD_LOG_DIR} && mkdir -p ${FD_LOG_DIR} - -nohup python -m fastdeploy.entrypoints.openai.api_server \ - --model ${MODEL_NAME} \ - --port ${D_PORT} \ - --metrics-port $((D_PORT + 1)) \ - --engine-worker-queue-port $((D_PORT + 2)) \ - --cache-queue-port $((D_PORT + 3)) \ - --max-model-len 32768 \ - --splitwise-role "decode" \ - --cache-transfer-protocol "rdma" \ - --rdma-comm-ports $((D_PORT + 4)) \ - --pd-comm-port $((D_PORT + 5)) \ - --scheduler-name "splitwise" \ - --scheduler-host "127.0.0.1" \ - --scheduler-port ${REDIS_PORT} \ - --scheduler-ttl 9000 \ - 2>&1 >${FD_LOG_DIR}/nohup & - -wait_for_health ${D_PORT} - - -# send request -sleep 10 # make sure server is registered to router -echo "send request..." -curl -X POST "http://0.0.0.0:${D_PORT}/v1/chat/completions" \ --H "Content-Type: application/json" \ --d '{ - "messages": [ - {"role": "user", "content": "hello"} - ], - "max_tokens": 20, - "stream": false -}' diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 82911eccfa3..9fd48cec2ce 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -14,13 +14,35 @@ # limitations under the License. """ +from dataclasses import dataclass from enum import Enum +from typing import Any, Optional from fastdeploy.utils import get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") +@dataclass +class AuxBlockDataSpec: + """ + Describes a type of auxiliary data bound to KVCache blocks. + CacheTransferManager iterates registered specs during swap/storage + to perform corresponding data transfers. + """ + + name: str + num_layers: int + per_token_size: int = 0 + block_size: int = 0 + dtype: str = "uint8" + swap_buffer: Optional[Any] = None + enabled: bool = True + + def get_storage_key(self, key_prefix: str, block_hash: str, rank: int) -> str: + return f"prefix{key_prefix}_{block_hash}_{rank}_{self.name}" + + class CacheStatus(Enum): """ cache status enum class @@ -56,6 +78,7 @@ def __init__( cache_status=CacheStatus.GPU, is_persistent=False, persistent_shared_count=0, + aux_data_names=None, ): """ Args: @@ -89,6 +112,7 @@ def __init__( self.cache_status = cache_status self.is_persistent = is_persistent self.persistent_shared_count = persistent_shared_count + self.aux_data_names = aux_data_names or [] self.req_id_set = set() def __lt__(self, other): @@ -102,7 +126,7 @@ def __lt__(self, other): else: return self.depth > other.depth - def __str__(self): + def __str__(self) -> str: """ return node info """ diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index b934c3e74c7..33407b785e5 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -613,12 +613,18 @@ def __init__( ) self.gpu_id = gpu_id - self.cache_info = dict() + self.cache_info = dict() # {'request_id': cache_info_dict} self.rank_id = self.rank + local_data_parallel_id * self.nranks self.engine_cache_task_thread_lock = threading.Lock() - self.engine_cache_tasks = [dict() for _ in range(512)] - self.idx_cache_task_dict = {} - self.cache_prefilled_engine_ids_queue = queue.Queue() # keep batch slot index for each prefill step + self.engine_cache_tasks = [ + dict() for _ in range(512) + ] # {'layer_id': {'prefilled_layer_idx': xx, 'prefilled_block_num': xx}} + self.idx_cache_task_dict = {} # {'slot_idx': cache_info_dict} + self.pending_layer0_signals = {} + self.pending_layer0_signal_lock = threading.Lock() + self.cache_prefilled_engine_ids_queue = ( + queue.Queue() + ) # [(slot_idx1, prefilled_token_num1), (slot_idx2, prefilled_token_num2)] if splitwise_role == "prefill": consume_signals_thread = threading.Thread(target=self.consume_signals) consume_signals_thread.daemon = True @@ -638,7 +644,6 @@ def _add_cache_task_thread(self): while True: try: cache_info = self.engine_worker_queue.get_cache_info() - finished_add_cache_task_req_ids = [] if cache_info: logger.debug(f"Get cache info from engine worker queue, {cache_info}") self.engine_worker_queue.cache_info_barrier.wait() @@ -647,7 +652,6 @@ def _add_cache_task_thread(self): self.cache_info[info["request_id"]].update(info) current_info = self.cache_info[info["request_id"]] assert "dest_block_ids" in current_info and "src_block_ids" in current_info - finished_add_cache_task_req_ids.append(info["request_id"]) decode_cached_block_num = len(current_info["src_block_ids"]) - len( current_info["dest_block_ids"] ) @@ -659,17 +663,34 @@ def _add_cache_task_thread(self): current_info["sended_layer_id"] = -1 current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size current_info["status"] = "init" - logger.info(f"Get cache info from D: finish add cache task: {current_info}") + logger.info(f"Get cache info and finish add cache task: {current_info}") self.cache_info[info["request_id"]] = current_info - self.idx_cache_task_dict[current_info["current_id"]] = current_info + current_id = current_info["current_id"] + with self.engine_cache_task_thread_lock: + self.idx_cache_task_dict[current_id] = current_info + with self.pending_layer0_signal_lock: + recovered_signal = self.pending_layer0_signals.pop(current_id, None) + if recovered_signal is not None: + _, prefilled_token_num = recovered_signal + if prefilled_token_num <= current_info["need_prefill_tokens"]: + recovered_signal_batch = [recovered_signal] + logger.info( + "cache_task_register_recover_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal_batch: {recovered_signal_batch}" + ) + self.cache_prefilled_engine_ids_queue.put(recovered_signal_batch) + else: + logger.info( + "cache_task_register_drop_layer0_signal: " + f"current_id: {current_id}, " + f"recovered_signal: {recovered_signal}, " + f"need_prefill_tokens: {current_info['need_prefill_tokens']}" + ) else: - logger.info(f"Get cache info from P: {info}") + logger.info(f"Get cache info: {info}") self.cache_info[info["request_id"]] = info - if finished_add_cache_task_req_ids: - logger.info(f"Put processed tasks into engine worker queue: {finished_add_cache_task_req_ids}") - self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids) - self.engine_worker_queue.finish_add_cache_task_barrier.wait() else: time.sleep(0.001) except Exception as e: @@ -684,13 +705,16 @@ def prefill_layerwise_send_cache_thread(self): try: batch_engine_signals = self.cache_prefilled_engine_ids_queue.get() self.engine_worker_queue.begin_send_cache_barrier.wait() + block_start_end_list = [] current_prefilled_token_num_list = [] for engine_index, current_step_prefilled_token_num in batch_engine_signals: + self._maybe_wait_for_cache_task(engine_index) assert ( engine_index in self.idx_cache_task_dict ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}" block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"] + prefilled_token_num = current_step_prefilled_token_num if ( prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"] @@ -842,9 +866,12 @@ def prefill_layerwise_send_cache_thread(self): logger.info( f"Put successful cache writing task in engine worker queue, req_id: {task['request_id']}, status: {task['status']}" ) - self.engine_cache_tasks[task["current_id"]] = dict() + current_id = task["current_id"] + self.engine_cache_tasks[current_id] = dict() del self.cache_info[task["request_id"]] - del self.idx_cache_task_dict[task["current_id"]] + del self.idx_cache_task_dict[current_id] + with self.pending_layer0_signal_lock: + self.pending_layer0_signals.pop(current_id, None) break except Exception as e: logger.error(f"prefill layerwise send cache thread has exception: {e} {traceback.format_exc()!s}") @@ -856,32 +883,42 @@ def consume_signals(self): while True: try: get_output_kv_signal(kv_signal_data, self.rank_id, 1) # wait_flag - if not self.cache_info: - time.sleep(0.01) - continue - tasks_count = kv_signal_data[0] + has_cache_info = bool(self.cache_info) + tasks_count = kv_signal_data[0].item() if tasks_count == -1: continue + if not has_cache_info: + logger.debug("consume_signals get kv signal before cache info is ready") layer_id = kv_signal_data[1].item() if layer_id == self.num_layers - 1: logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}") - batch_engine_signals = [] + ready_engine_signals = [] + pending_engine_signals = [] # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)] with self.engine_cache_task_thread_lock: for bi in range(tasks_count): engine_idx = kv_signal_data[3 * bi + 2].item() chuck_token_offset = kv_signal_data[3 * bi + 3].item() current_seq_len = kv_signal_data[3 * bi + 4].item() + prefilled_token_num = chuck_token_offset + current_seq_len self.engine_cache_tasks[engine_idx]["prefilled_layer_idx"] = layer_id - self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = ( - chuck_token_offset + current_seq_len - ) - batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len)) - if layer_id == 0: - logger.info( - f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue" - ) - self.cache_prefilled_engine_ids_queue.put(batch_engine_signals) + self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = prefilled_token_num + if layer_id == 0: + if engine_idx in self.idx_cache_task_dict: + ready_engine_signals.append((engine_idx, prefilled_token_num)) + else: + pending_engine_signals.append((engine_idx, prefilled_token_num)) + if pending_engine_signals: + with self.pending_layer0_signal_lock: + for engine_idx, prefilled_token_num in pending_engine_signals: + self.pending_layer0_signals[engine_idx] = (engine_idx, prefilled_token_num) + if pending_engine_signals: + logger.debug(f"cache_task_pending_layer0_signal: {pending_engine_signals}") + if ready_engine_signals: + logger.info( + f"Put batch_engine_signals {ready_engine_signals} into cache_prefilled_engine_ids_queue" + ) + self.cache_prefilled_engine_ids_queue.put(ready_engine_signals) except Exception as e: logger.error(f"Consume signals get exception: {e}") @@ -917,6 +954,20 @@ def _handle_connect_task(self): except Exception as e: logger.error(f"handle_connect_task has exception: {e}, {traceback.format_exc()}") + def _maybe_wait_for_cache_task(self, engine_index): + # If cache messager does not get cache task from engine, just hang here for now + wait_step = 1 + sleep_seconds = 0.005 + + while engine_index not in self.idx_cache_task_dict: + time.sleep(sleep_seconds) + wait_step += 1 + + if wait_step % 400 == 0: + logger.warning( + f"waiting cache task for engine_index: {engine_index}, cost_time: {wait_step * 0.005:.2f} s" + ) + def main(): device = args.device_id diff --git a/fastdeploy/cache_manager/cache_tasks.py b/fastdeploy/cache_manager/cache_tasks.py index fe15263827a..d50c809c0c6 100644 --- a/fastdeploy/cache_manager/cache_tasks.py +++ b/fastdeploy/cache_manager/cache_tasks.py @@ -15,7 +15,7 @@ """ from dataclasses import dataclass -from typing import List +from typing import List, Optional @dataclass(frozen=True, kw_only=True) @@ -35,3 +35,8 @@ class ReadStorageTask(CacheTask): @dataclass(frozen=True, kw_only=True) class WriteStorageTask(CacheTask): timeout: float = 30.0 + # Used in FD_AS_ONLY_FLUSH mode to indicate whether cache is present on this node. + # True = cache exists (request finish), False = cache gone (CPU eviction), None = not applicable. + flush_cache_exists: Optional[bool] = None + # Block index to start the write/flush operation from. Defaults to 0 (all blocks). + start_write_block_idx: int = 0 diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 85a113adf66..d3f26511372 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -24,6 +24,7 @@ import threading import time import traceback +import weakref from typing import List import numpy as np @@ -48,7 +49,7 @@ FileStore, MooncakeStore, ) -from fastdeploy.config import CacheConfig, SpeculativeConfig +from fastdeploy.config import CacheConfig, RoutingReplayConfig, SpeculativeConfig from fastdeploy.engine.request import ControlRequest, ControlResponse from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus from fastdeploy.inter_communicator.fmq import FMQ @@ -129,7 +130,11 @@ def parse_args(): ) parser.add_argument("--model_path", type=str, help="The path of model") + # Routing replay (R3) — single JSON arg, mirrors SpeculativeConfig pattern + parser.add_argument("--routing_replay_config", type=json.loads, default="{}", help="Routing replay config JSON") + args = parser.parse_args() + args.routing_replay_config = RoutingReplayConfig(args.routing_replay_config) return args @@ -208,9 +213,12 @@ def __init__(self, args): self.tansfer_done_queue = queue.Queue() # 用来告知任务执行完毕 self.ctrl_output_queue = None - address = (args.pod_ip, args.cache_queue_port) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + engine_cache_queue_address = (args.pod_ip, args.cache_queue_port) + else: + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{args.cache_queue_port}.sock" self.cache_task_queue = EngineCacheQueue( - address=address, + address=engine_cache_queue_address, is_server=False, num_client=args.mp_num, client_id=self.rank, @@ -241,6 +249,25 @@ def __init__(self, args): self._init_cpu_cache() if self.storage_backend_type is not None: self._init_storage(args) + + # Initialize auxiliary data specs (e.g., routing replay) + self.aux_data_specs = {} + self.routing_host_view = None + self.routing_swap_buffer = None + self.routing_replay_config = args.routing_replay_config + self.engine_worker_queue_port = args.engine_worker_queue_port + self._init_routing_aux_data() + + # Register finalizer to release routing SharedMemory on process exit. + # Must use a static method — callback must NOT hold a reference to self, + # otherwise the object can never be GC'd and the finalizer won't fire. + self._finalizer = weakref.finalize( + self, + CacheTransferManager._cleanup_routing_resources, + self.routing_swap_buffer, + self.routing_host_view, + ) + self._init_control() cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32) @@ -307,6 +334,185 @@ def __init__(self, args): ) self.cache_transfer_inited_signal.value[self.rank] = 1 + def _init_routing_aux_data(self): + """Initialize routing auxiliary data buffers for swap sync.""" + routing_replay_config = self.routing_replay_config + if not routing_replay_config.enable_routing_replay: + return + + try: + from fastdeploy.cache_manager.cache_data import AuxBlockDataSpec + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + RoutingSwapBuffer, + ) + + num_moe_layers = routing_replay_config.num_moe_layers + moe_top_k = routing_replay_config.moe_top_k + routing_dtype = routing_replay_config.routing_dtype + + if num_moe_layers == 0 or moe_top_k == 0: + return + + spec = AuxBlockDataSpec( + name="routing", + num_layers=num_moe_layers, + per_token_size=moe_top_k, + block_size=self.block_size, + dtype=routing_dtype, + ) + + # Create routing swap buffer (for CPU blocks). + # Only rank 0 needs it — _swap_routing() only runs on rank 0. + if self.num_cpu_blocks > 0 and self.rank == 0: + dp_suffix = str(self.engine_worker_queue_port) + self.routing_swap_buffer = RoutingSwapBuffer( + num_cpu_blocks=self.num_cpu_blocks, + block_size=self.block_size, + num_moe_layers=num_moe_layers, + top_k=moe_top_k, + dtype=routing_dtype, + dp_suffix=dp_suffix, + ) + spec.swap_buffer = self.routing_swap_buffer + + # Attach to routing host buffer (SharedMemory created by Engine) + dp_suffix = str(self.engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = self.num_gpu_blocks * self.block_size + shape = (max_num_kv_tokens, num_moe_layers, moe_top_k) + try: + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=routing_dtype, shm_name=shm_name) + logger.info(f"[R3] CTM attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + logger.warning(f"[R3] CTM RoutingHostBuffer {shm_name} not found") + + self.aux_data_specs["routing"] = spec + logger.info(f"[R3] CTM registered routing aux data: layers={num_moe_layers}, top_k={moe_top_k}") + + except Exception as e: + logger.warning(f"[R3] CTM failed to init routing aux data: {e}") + + @staticmethod + def _cleanup_routing_resources(routing_swap_buffer, routing_host_view): + """Release routing SharedMemory on process exit. Called by weakref.finalize.""" + if routing_swap_buffer is not None: + routing_swap_buffer.close() + if routing_host_view is not None: + routing_host_view.close() + + def _swap_routing(self, gpu_block_ids, cpu_block_ids, direction): + """ + Swap routing data between routing_host_buffer and routing_swap_buffer. + Pure CPU-to-CPU numpy memcpy, no GPU DMA. + Only rank 0 performs this (routing buffers are cross-rank SharedMemory). + """ + if self.routing_host_view is None or self.routing_swap_buffer is None: + logger.warning( + f"[R3] _swap_routing skipped: host_view={self.routing_host_view is not None}, " + f"swap_buffer={self.routing_swap_buffer is not None}" + ) + return + if self.rank > 0: + return + bs = self.block_size + for gpu_bid, cpu_bid in zip(gpu_block_ids, cpu_block_ids): + gpu_start = gpu_bid * bs + gpu_end = gpu_start + bs + cpu_start = cpu_bid * bs + cpu_end = cpu_start + bs + if direction == "to_cpu": + self.routing_swap_buffer.buffer[cpu_start:cpu_end] = self.routing_host_view.buffer[gpu_start:gpu_end] + elif direction == "to_gpu": + self.routing_host_view.buffer[gpu_start:gpu_end] = self.routing_swap_buffer.buffer[cpu_start:cpu_end] + else: + raise ValueError(f"[R3] _swap_routing: unknown direction '{direction}', expected 'to_cpu' or 'to_gpu'") + logger.info( + f"[R3] _swap_routing {direction}: {len(gpu_block_ids)} blocks, " + f"gpu_ids={gpu_block_ids[:3]}{'...' if len(gpu_block_ids) > 3 else ''}, " + f"cpu_ids={cpu_block_ids[:3]}{'...' if len(cpu_block_ids) > 3 else ''}" + ) + + def _write_routing_to_storage(self, task_keys, gpu_block_ids): + """ + Write routing data from routing_host_buffer to storage backend. + Only for mooncake/file backends; only tp_rank=0 writes routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + routing_keys = [] + routing_ptrs = [] + routing_sizes = [] + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + block_data = self.routing_host_view.buffer[start:end] + if not block_data.flags["C_CONTIGUOUS"]: + block_data = np.ascontiguousarray(block_data) + routing_keys.append(key) + routing_ptrs.append(block_data.ctypes.data) + routing_sizes.append(per_block_bytes) + + if routing_keys: + self.storage_backend.batch_set( + keys=routing_keys, target_locations=routing_ptrs, target_sizes=routing_sizes + ) + logger.debug(f"[R3] Wrote {len(routing_keys)} routing blocks to storage") + except Exception as e: + logger.warning(f"[R3] Failed to write routing to storage: {e}") + + def _read_routing_from_storage(self, task_keys, gpu_block_ids): + """ + Read routing data from storage backend into routing_host_buffer. + Only for mooncake/file backends; only tp_rank=0 reads routing. + """ + if self.routing_host_view is None or self.rank != 0: + return + if self.storage_backend_type not in ("mooncake", "file"): + return + + try: + spec = self.aux_data_specs.get("routing") + if spec is None or not spec.enabled: + return + + bs = self.block_size + per_block_bytes = bs * spec.num_layers * spec.per_token_size * np.dtype(spec.dtype).itemsize + + for block_hash, gpu_bid in zip(task_keys, gpu_block_ids): + key = spec.get_storage_key(self.key_prefix, block_hash, self.rank) + start = gpu_bid * bs + end = start + bs + target_slice = self.routing_host_view.buffer[start:end] + if not target_slice.flags["C_CONTIGUOUS"]: + # Need contiguous target for ctypes pointer + tmp = np.ascontiguousarray(target_slice) + result = self.storage_backend.get( + key=key, target_location=tmp.ctypes.data, target_size=per_block_bytes + ) + if result is not None and result >= 0: + self.routing_host_view.buffer[start:end] = tmp + else: + self.storage_backend.get( + key=key, target_location=target_slice.ctypes.data, target_size=per_block_bytes + ) + + logger.debug(f"[R3] Read {len(task_keys)} routing blocks from storage") + except Exception as e: + logger.warning(f"[R3] Failed to read routing from storage: {e}") + def _init_control(self): dp_rank = self.local_data_parallel_id tp_rank = self.rank @@ -320,7 +526,8 @@ def _init_storage(self, args): try: # TODO: support cache scale for other backend if self.has_cache_scale and self.storage_backend_type is not None: - if self.storage_backend_type not in ["mooncake"]: + is_as_only_flush = envs.FD_AS_ONLY_FLUSH and self.storage_backend_type == "attention_store" + if not is_as_only_flush and self.storage_backend_type not in ["mooncake"]: raise ValueError( f"Unsupported storage backend ({self.storage_backend_type}) " "when cache quantization is block_wise_fp8" @@ -348,6 +555,7 @@ def _init_storage(self, args): * self.cache_item_bytes, device_id=self.device, dp_id=self.local_data_parallel_id, + splitwise_role=getattr(args, "splitwise_role", "mixed"), ) logger.info("Initialized attention store successfully!") elif args.kvcache_storage_backend == "file": @@ -535,23 +743,28 @@ def _init_gpu_cache(self): logger.info("GPU KV cache is initialized") def _clear_gpu_cache(self): + if self.create_cache_tensor: logger.debug("Waiting for gpu runner to unlink cuda ipc") while self.cache_ready_signal.value[self.rank] != 0: time.sleep(0.1) logger.debug("Stop waiting! gpu runner has unlinked cuda ipc") - self.gpu_cache_kvs.clear() - self.gpu_cache_k_tensors.clear() - self.gpu_cache_v_tensors.clear() - if hasattr(self, "gpu_cache_scales_k_tensors"): - self.gpu_cache_scales_k_tensors.clear() - if hasattr(self, "gpu_cache_scales_v_tensors"): - self.gpu_cache_scales_v_tensors.clear() - paddle.device.cuda.empty_cache() else: for name, tensor in self.gpu_cache_kvs.items(): unset_data_ipc(tensor, name, True, False) logger.debug("Successfully unlinked gpu caches cuda ipc") + + self.gpu_cache_kvs.clear() + self.gpu_cache_k_tensors.clear() + self.gpu_cache_v_tensors.clear() + if hasattr(self, "gpu_cache_scales_k_tensors"): + self.gpu_cache_scales_k_tensors.clear() + if hasattr(self, "gpu_cache_scales_v_tensors"): + self.gpu_cache_scales_v_tensors.clear() + paddle.set_flags({"FLAGS_selected_gpus": f"{self.device}"}) + paddle.device.cuda.empty_cache() + + if not self.create_cache_tensor: self.cache_ready_signal.value[self.rank] = 0 while np.sum(self.cache_ready_signal.value) != 0: @@ -809,6 +1022,9 @@ def read_storage_task(self, task: ReadStorageTask): logger.info( f"Successfully read {len(valid_gpu_block_ids)} blocks from cache storage for task {task.task_id}" ) + # Read routing data from storage for matched blocks + matched_keys = task.keys[: len(valid_gpu_block_ids)] + self._read_routing_from_storage(matched_keys, valid_gpu_block_ids) except Exception as e: logger.error( f"Failed to read cache for task {task.task_id}, error: {e}, traceback: {traceback.format_exc()}" @@ -915,13 +1131,42 @@ def _run_write_back_storage( target_sizes.extend([self.scale_buffer_stride_bytes] * block_num * 2) start_time = time.time() - self.storage_backend.batch_set(keys=keys, target_locations=target_locations, target_sizes=target_sizes) + result = self.storage_backend.batch_set( + keys=keys, target_locations=target_locations, target_sizes=target_sizes + ) write_cost_time = time.time() - start_time + # Per-block success validation (same pattern as _run_read_storage) + # batch_set returns List[int]: 0 = success, negative = error + if k_scale_keys and v_scale_keys: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + k_scale_result = result[2 * block_num : 3 * block_num] + v_scale_result = result[3 * block_num :] + success_block_num = 0 + for k, v, ks, vs in zip(k_result, v_result, k_scale_result, v_scale_result): + if not (k == 0 and v == 0 and ks == 0 and vs == 0): + break + success_block_num += 1 + else: + k_result = result[:block_num] + v_result = result[block_num : 2 * block_num] + success_block_num = 0 + for k, v in zip(k_result, v_result): + if not (k == 0 and v == 0): + break + success_block_num += 1 + + if success_block_num < block_num: + logger.error( + f"_run_write_back_storage partial failure: " + f"{success_block_num}/{block_num} blocks written, task_id: {task_id}" + ) + logger.debug( f"_run_write_back_storage, swap_cost_time: {swap_cost_time:.6f}s, write_cost_time: {write_cost_time:.6f}s" ) - return block_num + return success_block_num elif self.storage_backend_type == "attention_store": key_cache = [] @@ -944,6 +1189,34 @@ def _run_write_back_storage( ) return 0 + def _flush_only_storage_task(self, task: WriteStorageTask): + """ + AS-only flush mode: skip actual storage write, only report cache index to AttentionStore. + Used when FD_AS_ONLY_FLUSH is enabled — AS acts as index-only (no data storage). + + Args: + task: WriteStorageTask with flush_cache_exists indicating cache state: + True/None = cache present on this node (request finish) + False = cache gone from this node (eviction) + """ + try: + if (self.rank == 0) and self.storage_backend_type == "attention_store": + reside_in_gpu = task.flush_cache_exists if task.flush_cache_exists is not None else True + self.storage_backend.flush_token_index( + task.task_id, task.token_ids, task.start_write_block_idx, reside_in_gpu + ) + logger.info( + f"[AS_ONLY_FLUSH] flush token index reside_in_gpu={reside_in_gpu} " + f"start_block_idx={task.start_write_block_idx} for task {task.task_id}" + ) + except Exception as e: + logger.warning(f"[AS_ONLY_FLUSH] Failed to flush token index for task {task.task_id}, error: {e}") + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, []) + self.cache_task_queue.swap_to_storage_barrier.wait() + if self.rank == 0: + self.cache_task_queue.swap_to_storage_barrier.reset() + self.cache_task_queue.put_transfer_done_signal(result) + def write_back_storage_task(self, task: WriteStorageTask): """ Write cache to the storage backend from the GPU memory. @@ -952,6 +1225,9 @@ def write_back_storage_task(self, task: WriteStorageTask): self.storage_backend ), f"storage_backend not initialized, storage_backend_type: {self.storage_backend_type}" + if envs.FD_AS_ONLY_FLUSH: + return self._flush_only_storage_task(task) + try: gpu_block_ids = task.gpu_block_ids.copy() cpu_block_ids = [i for i in range(len(gpu_block_ids))] @@ -975,14 +1251,13 @@ def write_back_storage_task(self, task: WriteStorageTask): if match_block_num >= len(k_cache_keys): logger.info(f"No uncached keys found for task {task.task_id}") - gpu_block_ids = [] else: try: k_cache_keys = k_cache_keys[match_block_num:] v_cache_keys = v_cache_keys[match_block_num:] k_scale_keys = k_scale_keys[match_block_num:] if k_scale_keys else None v_scale_keys = v_scale_keys[match_block_num:] if v_scale_keys else None - gpu_block_ids = gpu_block_ids[match_block_num:] + write_gpu_block_ids = gpu_block_ids[match_block_num:] cpu_block_ids = cpu_block_ids[match_block_num:] # TODO: support timeout with actual block count write_block_num = self._run_write_back_storage( @@ -993,16 +1268,28 @@ def write_back_storage_task(self, task: WriteStorageTask): v_cache_keys, k_scale_keys, v_scale_keys, - gpu_block_ids, + write_gpu_block_ids, cpu_block_ids, task.timeout, ) logger.info( f"Successfully wrote {write_block_num} blocks to cache storage for task {task.task_id}" ) + # Check for partial write failure + if write_block_num < len(write_gpu_block_ids): + logger.error( + f"Partial write failure for task {task.task_id}: " + f"{write_block_num}/{len(write_gpu_block_ids)} blocks written" + ) + # Report: match_block_num (already cached) + write_block_num (newly written) + gpu_block_ids = gpu_block_ids[: match_block_num + write_block_num] + # Write routing data to storage only for actually-written blocks + written_block_ids = write_gpu_block_ids[:write_block_num] + remaining_keys = task.keys[match_block_num : match_block_num + len(written_block_ids)] + self._write_routing_to_storage(remaining_keys, written_block_ids) except Exception as e: logger.error(f"Error in write back storage task: {e}, traceback:{traceback.format_exc()}") - gpu_block_ids = [] + gpu_block_ids = gpu_block_ids[:match_block_num] finally: try: if (self.rank == 0) and self.storage_backend_type == "attention_store": @@ -1015,14 +1302,19 @@ def write_back_storage_task(self, task: WriteStorageTask): result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, gpu_block_ids) self.cache_task_queue.swap_to_storage_barrier.wait() - if self.rank == 0: # 只有当rank为0时执行同步操作 - self.cache_task_queue.swap_to_storage_barrier.reset() - self.cache_task_queue.put_transfer_done_signal(result) # 发送传输完成信号 - logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") + self.cache_task_queue.put_transfer_done_signal(result) + logger.debug(f"write_back_storage_task: put_transfer_done_signal {result}") except Exception as e: logger.error( f"An error occurred in write_back_storage_task, " f"error: {e}, traceback:\n{traceback.format_exc()}" ) + # Prevent caller from blocking forever: send empty done signal + try: + result = (CacheStatus.GPU2STORAGE, task.task_id, task.keys, []) + self.cache_task_queue.swap_to_storage_barrier.wait() + self.cache_task_queue.put_transfer_done_signal(result) + except Exception as barrier_err: + logger.error(f"Failed to send failure signal for task {task.task_id}: {barrier_err}") def _do_swap_to_cpu_task( self, @@ -1384,6 +1676,10 @@ def _transfer_data( 0, ) + # Routing: routing_host_buffer → routing_swap_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_cpu") + elif event_type.value == CacheStatus.SWAP2GPU.value: swap_cache_all_layers( self.gpu_cache_k_tensors, @@ -1422,6 +1718,11 @@ def _transfer_data( self.device, 1, ) + + # Routing: routing_swap_buffer → routing_host_buffer + if "routing" in self.aux_data_specs: + self._swap_routing(gpu_block_ids, cpu_block_ids, "to_gpu") + else: logger.warning( f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported" diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 52cd83682eb..b21b4349172 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -37,6 +37,8 @@ from fastdeploy.engine.request import Request from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.trace.constants import LoggingEventName +from fastdeploy.trace.trace_logger import print as trace_print from fastdeploy.utils import get_hash_str, get_logger logger = get_logger("prefix_cache_manager", "cache_manager.log") @@ -95,6 +97,7 @@ def __init__( self.kvcache_storage_backend = self.cache_config.kvcache_storage_backend self.write_policy = self.cache_config.write_policy self.task_write_back_event = {} + self.storage_write_back_result = {} self.task_prefetch_event = {} self.storage_prefetch_block_ids = {} @@ -211,8 +214,12 @@ def launch_cache_manager( create=True, ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + engine_cache_queue_address = (pod_ip, cache_config.local_cache_queue_port) + else: + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{cache_config.local_cache_queue_port}.sock" self.cache_task_queue = EngineCacheQueue( - address=(pod_ip, cache_config.local_cache_queue_port), + address=engine_cache_queue_address, authkey=b"cache_queue_service", is_server=False, num_client=tensor_parallel_size, @@ -290,9 +297,18 @@ def launch_cache_manager( val_cache_arg_str = f" --value_cache_shape {val_shape_str}" if cache_config.kvcache_storage_backend: storage_arg_str = f" --kvcache_storage_backend {cache_config.kvcache_storage_backend}" + if not self.enable_splitwise: + storage_arg_str += " --create_cache_tensor" else: storage_arg_str = " " + # Compute routing replay args for CTM — single JSON arg + routing_replay_config = getattr(self.config, "routing_replay_config", None) + if routing_replay_config is not None and routing_replay_config.enable_routing_replay: + routing_arg_str = f" --routing_replay_config '{routing_replay_config.to_json_string()}'" + else: + routing_arg_str = "" + if self.cache_config.num_cpu_blocks > 0 or self.cache_config.kvcache_storage_backend: for i in range(tensor_parallel_size): launch_cmd = ( @@ -300,6 +316,7 @@ def launch_cache_manager( + visible_devices + " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0" + f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}" + + f" FD_AS_ONLY_FLUSH={int(envs.FD_AS_ONLY_FLUSH)}" + f" {sys.executable} {py_path}" + f" --device_id {int(device_ids[i])}" + f" --rank {i}" @@ -319,11 +336,11 @@ def launch_cache_manager( + f" --rdma_port {cache_config.local_rdma_comm_ports[i] if cache_config.local_rdma_comm_ports is not None else '0'}" + f" --speculative_config '{self.speculative_config.to_json_string()}'" + f" --default_dtype '{self.config.model_config.dtype}'" - + (" --create_cache_tensor" if not self.enable_splitwise else "") + storage_arg_str + f" --write_policy {cache_config.write_policy}" + f" --max_model_len {self.config.model_config.max_model_len}" + f" --model_path {self.config.model_config.model}" + + routing_arg_str + f" >{log_dir}/launch_cache_transfer_manager_{int(device_ids[i])}.log 2>&1" ) logger.info(f"Launch cache transfer manager, command:{launch_cmd}") @@ -715,6 +732,8 @@ def update_cache_blocks(self, task, block_size, num_computed_tokens): req_id = task.request_id last_node, num_cached_tokens = self.req_to_radix_tree_info[req_id] can_cache_computed_tokens = num_computed_tokens - num_computed_tokens % block_size + if can_cache_computed_tokens <= num_cached_tokens: + return if req_id in self.leaf_req_map[last_node]: # delete old leaf record, update later self.leaf_req_map[last_node].remove(req_id) logger.debug( @@ -844,7 +863,7 @@ def request_match_blocks(self, task: Request, block_size, *args): storage_match_token_num = 0 match_storage_block_ids = [] - if self.kvcache_storage_backend and no_match_token_num >= block_size: + if self.kvcache_storage_backend and no_match_token_num >= block_size and not envs.FD_AS_ONLY_FLUSH: if not self.can_allocate_gpu_blocks(num_blocks=no_match_block_num, try_free_gpu_blocks=False): raise Exception( "request_match_blocks: Not enough GPU memory to allocate cache for matched Storage Cache" @@ -1157,6 +1176,7 @@ def write_cache_to_storage(self, request: Request): if not keys: return + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) gpu_block_ids = request.block_tables[: len(keys)] logger.info(f"start write cache back to storage, req_id: {req_id}, block num: {len(keys)}") write_storage_task = WriteStorageTask( @@ -1167,9 +1187,16 @@ def write_cache_to_storage(self, request: Request): ) logger.debug(f"issue write storage task: {write_storage_task}") tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"write cache back to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"finish write cache back to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) def write_cache_to_storage_decode(self, request: Request): """ @@ -1237,6 +1264,7 @@ def write_cache_to_storage_decode(self, request: Request): # Incremental logic is handled by CacheTransferManager.write_back_storage_task() req_id = request.request_id logger.info(f"[D instance] start write cache to storage, req_id: {req_id}, block num: {len(keys)}") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) write_storage_task = WriteStorageTask( task_id=req_id, @@ -1246,33 +1274,246 @@ def write_cache_to_storage_decode(self, request: Request): ) tic = time.time() - self.issue_write_back_storage_task(write_storage_task, is_sync=True) + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) cost_time = time.time() - tic - logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + if not success: + logger.error( + f"[D instance] write cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info(f"[D instance] finish write cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s") + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + + def _compute_pd_storage_keys(self, request: Request, input_token_ids: list): + """ + Compute cache keys (including :partial:N suffix for last incomplete block) + for PD storage-pool mode. Used by both write_all_cache_to_storage (P/D) and + read_cache_from_storage_for_pd (D) to ensure consistent key computation. + + Args: + request: The request object (needed for get_block_hash_extra_keys). + input_token_ids: The token IDs to compute keys for. + + Returns: + list: The computed hash keys for each block. + """ + keys = [] + prefix_block_key = [] + block_size = self.config.cache_config.block_size + mm_idx = 0 + + for i in range(0, len(input_token_ids), block_size): + block_token_ids = input_token_ids[i : i + block_size] + actual_token_num = len(block_token_ids) + + if actual_token_num < block_size: + # Last incomplete block: compute key with actual tokens + partial marker + key = get_hash_str(block_token_ids, prefix_block_key) + key = f"{key}:partial:{actual_token_num}" + keys.append(key) + else: + # Full block: compute key normally + mm_idx, extra_keys = self.get_block_hash_extra_keys( + request=request, + start_idx=i, + end_idx=i + block_size, + mm_idx=mm_idx, + ) + prefix_block_key.extend(extra_keys) + key = get_hash_str(block_token_ids, prefix_block_key) + keys.append(key) + + prefix_block_key = [key] + + return keys + + def write_all_cache_to_storage(self, request: Request, include_output=True): + """ + Write ALL token cache (including last incomplete block) to storage. + Used in PD storage-pool mode where P writes to storage instead of RDMA to D, + and D writes back all cache (including output tokens) on request completion. + + Unlike write_cache_to_storage_decode which skips incomplete blocks, this method + writes the last incomplete block by padding it to block_size in the storage key + computation (using a ':partial:N' suffix on the key). + + The actual GPU block is still full-sized, so swap_cache_layout works normally. + + Args: + request: The request object. + include_output: If True, include output_token_ids in the write (used by D). + If False, only write prompt_token_ids (used by P). + + Returns: + bool: True if all blocks written successfully, False otherwise. + """ + if self.kvcache_storage_backend is None: + return True + + # 1. Get complete token_ids + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + + input_token_ids = token_ids + request.output_token_ids if include_output else token_ids + + # 2. Calculate cache keys using shared helper + keys = self._compute_pd_storage_keys(request, input_token_ids) + + if not keys: + return True + + # 3. Get corresponding gpu_block_ids + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Construct WriteStorageTask and send + req_id = request.request_id + logger.info( + f"[PD Storage] start write all cache to storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_START, request.request_id, getattr(request, "user", "")) + + write_storage_task = WriteStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + ) + + tic = time.time() + success = self.issue_write_back_storage_task(write_storage_task, is_sync=True) + cost_time = time.time() - tic + if not success: + logger.error( + f"[PD Storage] write all cache to storage FAILED, req_id: {req_id}, " + f"block num: {len(keys)}, cost_time: {cost_time:.6f}s" + ) + else: + logger.info( + f"[PD Storage] finish write all cache to storage, req_id: {req_id}, cost_time: {cost_time:.6f}s" + ) + trace_print(LoggingEventName.WRITE_CACHE_TO_STORAGE_END, request.request_id, getattr(request, "user", "")) + return success + + def read_cache_from_storage_for_pd(self, request: Request): + """ + PD storage-pool mode: D instance reads cache from storage that P wrote. + + This is different from request_match_blocks() storage read: + - Called on D instance after receiving first_token notification from P + - Reads ALL blocks (including last partial block) that P wrote to storage + - Target gpu_block_ids are D's pre-allocated blocks + + Returns: + list: gpu_block_ids if all blocks fetched successfully, + empty list if any block failed to fetch (caller should abort this request). + """ + if self.kvcache_storage_backend is None: + return [] + + # 1. Get token_ids (same as what P prefilled) + token_ids = request.prompt_token_ids + if isinstance(token_ids, np.ndarray): + token_ids = token_ids.tolist() + else: + token_ids = list(token_ids) + input_token_ids = token_ids + + # 2. Calculate cache keys using shared helper (same algorithm as write_all_cache_to_storage) + keys = self._compute_pd_storage_keys(request, token_ids) + + if not keys: + return [] + + # 3. gpu_block_ids = D's pre-allocated block_tables + gpu_block_ids = request.block_tables[: len(keys)] + + # 4. Issue ReadStorageTask + req_id = request.request_id + logger.info( + f"[PD Storage] D start read cache from storage, req_id: {req_id}, " + f"block num: {len(keys)}, total_tokens: {len(input_token_ids)}" + ) + + read_task = ReadStorageTask( + task_id=req_id, + keys=keys, + token_ids=input_token_ids if self.kvcache_storage_backend == "attention_store" else None, + gpu_block_ids=gpu_block_ids, + start_read_block_idx=0, + ) + + tic = time.time() + storage_block_ids = self.issue_prefetch_storage_task(read_task, is_sync=True) + cost_time = time.time() - tic + + if len(storage_block_ids) != len(keys): + logger.error( + f"[PD Storage] D failed to read all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return [] + else: + logger.info( + f"[PD Storage] D finish reading the cache of all blocks from storage, req_id: {req_id}, " + f"matched blocks: {len(storage_block_ids)}/{len(keys)}, cost_time: {cost_time:.6f}s" + ) + return storage_block_ids def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True): + """ + Issue a write-back storage task. + Returns True if all blocks written successfully (sync mode), True always (async mode). + """ if self.kvcache_storage_backend is None: - return + return True - if len(task.keys) != len(task.gpu_block_ids): + if not envs.FD_AS_ONLY_FLUSH and len(task.keys) != len(task.gpu_block_ids): err_msg = ( f"write_back_storage error: hash_keys({len(task.keys)}) != gpu_block_ids({len(task.gpu_block_ids)})" ) logger.error(err_msg) raise ValueError(err_msg) - self.task_write_back_event[task.task_id] = Event() + if is_sync: + self.task_write_back_event[task.task_id] = Event() self.cache_task_queue.put_transfer_task((CacheStatus.GPU2STORAGE, task)) if is_sync: - self.wait_write_storage_task(task.task_id) + return self.wait_write_storage_task(task.task_id, expected_block_num=len(task.gpu_block_ids)) + return True - def wait_write_storage_task(self, req_id): + def wait_write_storage_task(self, req_id, expected_block_num=0, timeout=60.0): """ - Sync write back task + Sync write back task. + Returns True if all expected blocks written successfully across all TP ranks. + + Args: + req_id: request ID + expected_block_num: number of blocks expected to be written + timeout: max wait time in seconds """ if req_id in self.task_write_back_event: - self.task_write_back_event[req_id].wait() + success = self.task_write_back_event[req_id].wait(timeout=timeout) del self.task_write_back_event[req_id] + if not success: + logger.error(f"[PD Storage] write storage task timeout after {timeout}s, req_id: {req_id}") + self.storage_write_back_result.pop(req_id, None) + return False + # Check actual written block count vs expected + written_block_ids = self.storage_write_back_result.pop(req_id, []) + actual_written = len(written_block_ids) + if expected_block_num > 0 and actual_written < expected_block_num: + logger.error( + f"[PD Storage] write storage incomplete: {actual_written}/{expected_block_num} blocks, " + f"req_id: {req_id}" + ) + return False + return True + return True def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True): """ @@ -1448,6 +1689,7 @@ def free_block_ids_async(self, need_block_num): hash_value_swap_node_ids_map = defaultdict(list) hash_value_gpu_block_ids_map = defaultdict(list) + hash_value_flush_info = {} # {input_hash_value: (token_ids, min_depth)} total_gpu_free_count = 0 while True: @@ -1460,6 +1702,10 @@ def free_block_ids_async(self, need_block_num): self.gpu_lru_leaf_set.remove(node) if self.cache_config.num_cpu_blocks < need_block_num: if node.shared_count == 0 and node.is_gpu_leaf_node: # 直接回收 + if envs.FD_AS_ONLY_FLUSH and self.kvcache_storage_backend == "attention_store": + key = node.input_hash_value + if key not in hash_value_flush_info or node.depth < hash_value_flush_info[key][1]: + hash_value_flush_info[key] = (node.input_ids, node.depth) self._handle_free_gpu_node_without_cpu(node) total_gpu_free_count += 1 cur_node = node @@ -1509,6 +1755,22 @@ def free_block_ids_async(self, need_block_num): f"free_block_ids_async: need_block_num {need_block_num}, free_block_num {total_gpu_free_count}." ) + if ( + envs.FD_AS_ONLY_FLUSH + and self.kvcache_storage_backend == "attention_store" + and hash_value_flush_info + ): + for input_hash_value, (token_ids, min_depth) in hash_value_flush_info.items(): + flush_task = WriteStorageTask( + task_id=str(uuid.uuid4()), + keys=[input_hash_value], + token_ids=token_ids, + gpu_block_ids=[], + flush_cache_exists=False, + start_write_block_idx=min_depth - 1, + ) + self.issue_write_back_storage_task(flush_task, is_sync=False) + # swap cache to cpu if hash_value_gpu_block_ids_map: self.cpu_free_future = None @@ -2184,8 +2446,16 @@ def recv_data_transfer_result(self): elif event_type.value == CacheStatus.GPU2STORAGE.value: logger.debug(f"recv_data_transfer_result: {data}") task_id, hash_keys, block_ids = data[1:] - if task_id in self.task_write_back_event: - self.task_write_back_event[task_id].set() + # Collect results from all TP ranks (same pattern as STORAGE2GPU path) + if task_id not in self.storage_write_back_result: + self.storage_write_back_result[task_id] = [] + saved_results = self.storage_write_back_result[task_id] + saved_results.append(block_ids) + if len(saved_results) == self.tensor_parallel_size: + # Take minimum across all ranks (conservative, same as read path) + self.storage_write_back_result[task_id] = min(saved_results, key=len) + if task_id in self.task_write_back_event: + self.task_write_back_event[task_id].set() else: ( event_type, diff --git a/fastdeploy/cache_manager/routing_cache_manager.py b/fastdeploy/cache_manager/routing_cache_manager.py new file mode 100644 index 00000000000..f086b6e5821 --- /dev/null +++ b/fastdeploy/cache_manager/routing_cache_manager.py @@ -0,0 +1,321 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import math +import multiprocessing +import multiprocessing.shared_memory +from typing import Optional + +import numpy as np + +from fastdeploy.utils import get_logger + +logger = get_logger("routing_cache_manager", "routing_cache_manager.log") + + +class RoutingHostBuffer: + """ + Manages routing_host_buffer (corresponds to KVCache GPU cache). + Indexed by gpu_block_id * block_size + offset. + Shared across processes via POSIX SharedMemory. + Each DP rank creates its own instance; name includes dp_suffix. + """ + + def __init__( + self, num_gpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_gpu_tokens = num_gpu_blocks * block_size + self.shape = (max_num_gpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_host_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingHostBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingHostBufferView: + """Read/write view of routing_host_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def scatter(self, slot_mapping: np.ndarray, data: np.ndarray): + """Scatter GPU buffer data to corresponding slots (Worker calls this).""" + self.buffer[slot_mapping] = data + + def gather(self, slot_mapping: np.ndarray) -> np.ndarray: + """Gather data from specified slots (TokenProcessor calls this).""" + return self.buffer[slot_mapping].copy() + + def close(self): + self.shm.close() + + +class RoutingSwapBuffer: + """ + Manages routing_swap_buffer (corresponds to KVCache CPU cache). + Indexed by cpu_block_id * block_size + offset. + CacheTransferManager creates this; shared via SharedMemory. + """ + + def __init__( + self, num_cpu_blocks: int, block_size: int, num_moe_layers: int, top_k: int, dtype: str, dp_suffix: str = "" + ): + max_num_cpu_tokens = num_cpu_blocks * block_size + self.shape = (max_num_cpu_tokens, num_moe_layers, top_k) + self.dtype = np.dtype(dtype) + self.block_size = block_size + total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize + + self.shm_name = f"routing_swap_buffer.{dp_suffix}" + # Clean up stale SharedMemory from previous crashed process + try: + stale = multiprocessing.shared_memory.SharedMemory(name=self.shm_name, create=False) + stale.close() + stale.unlink() + logger.warning(f"[R3] Cleaned up stale SharedMemory: {self.shm_name}") + except FileNotFoundError: + pass + self.shm = multiprocessing.shared_memory.SharedMemory( + create=True, size=max(total_bytes, 1), name=self.shm_name + ) + self.buffer = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf) + self.buffer[:] = -1 # unsigned wrap: uint8→255, uint16→65535, uint32→4294967295 + + self._owner = True + logger.info( + f"[R3] Created RoutingSwapBuffer: shape={self.shape}, " + f"size={total_bytes / 1024:.1f} KB, name={self.shm_name}" + ) + + def close(self): + """Close and unlink SharedMemory. Only the owner (creator) unlinks.""" + self.shm.close() + if self._owner: + self.shm.unlink() + self._owner = False + + +class RoutingSwapBufferView: + """Read/write view of routing_swap_buffer (cross-process, does not own).""" + + def __init__(self, shape, dtype: str, shm_name: str): + self.shm = multiprocessing.shared_memory.SharedMemory(name=shm_name, create=False) + self.dtype = np.dtype(dtype) + self.buffer = np.ndarray(shape, dtype=self.dtype, buffer=self.shm.buf) + + def close(self): + self.shm.close() + + +def split_request_id(request_id: str) -> str: + """ + Split the request id to get rollout id. + + request_id: "chatcmpl-request.user-uuid" + rollout_id: "request.user" + example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" + -> "xxx_xxx_epoch_15:2:2:1" + """ + chat_type, tmp_str = request_id.split("-", 1) + assert ( + chat_type == "chatcmpl" + ), "Rollout Routing Replay only supports chatcmpl. Please check request type and userid settings." + reversed_tmp_str = tmp_str[::-1].split("-", 5) + rollout_id = reversed_tmp_str[-1][::-1] + return rollout_id + + +class RoutingCacheManager: + """ + Engine-side stateless routing data manager. + Does NOT maintain request mapping — request state is fully managed by Scheduler. + Responsible for: SharedMemory creation/destruction, routing data gather, return mode dispatch. + """ + + def __init__(self, fd_config, num_gpu_blocks: int): + routing_replay_config = fd_config.routing_replay_config + self.num_moe_layers = routing_replay_config.num_moe_layers + self.moe_top_k = routing_replay_config.moe_top_k + self.routing_dtype = routing_replay_config.routing_dtype + self.only_last_turn = routing_replay_config.only_last_turn + self.use_fused_put = routing_replay_config.use_fused_put + self.debug_mode = routing_replay_config.debug_mode + self.block_size = fd_config.cache_config.block_size + self.return_mode = ( + routing_replay_config.routing_store_type + ) # "local" / "rdma" → p2pstore; "response" → attach to RequestOutput + + dp_suffix = str(fd_config.parallel_config.local_engine_worker_queue_port) + + # Create SharedMemory routing_host_buffer + self.host_buffer = RoutingHostBuffer( + num_gpu_blocks=num_gpu_blocks, + block_size=self.block_size, + num_moe_layers=self.num_moe_layers, + top_k=self.moe_top_k, + dtype=self.routing_dtype, + dp_suffix=dp_suffix, + ) + + # Host view for gather operations + self.host_view = RoutingHostBufferView( + shape=self.host_buffer.shape, + dtype=self.routing_dtype, + shm_name=self.host_buffer.shm_name, + ) + + # Initialize store wrapper for p2pstore mode + self._store_wrapper = None + if self.return_mode in ("local", "rdma"): + from fastdeploy.cache_manager.routing_store import StoreWrapper + + self._store_wrapper = StoreWrapper(fd_config=fd_config) + self._store_wrapper.start_store_warpper() + + logger.info( + f"[R3] RoutingCacheManager initialized: return_mode={self.return_mode}, " + f"host_buffer shape={self.host_buffer.shape}" + ) + + def gather_routing_for_request(self, block_table, seq_len: int) -> np.ndarray: + """ + Gather complete routing data for a request from routing_host_buffer. + + Args: + block_table: List of block IDs for the request + seq_len: Total sequence length + + Returns: + routing_data: [seq_len, num_moe_layers, top_k] numpy array + """ + num_blocks = math.ceil(seq_len / self.block_size) + block_ids = block_table[:num_blocks] + positions = np.arange(seq_len) + block_indices = positions // self.block_size + offsets = positions % self.block_size + slot_mapping = np.array(block_ids)[block_indices] * self.block_size + offsets + routing_data = self.host_view.gather(slot_mapping) + + if self.debug_mode: + expected_routing = np.arange(seq_len, dtype=routing_data.dtype)[:, None, None] + expected_routing = np.broadcast_to(expected_routing, (seq_len, self.num_moe_layers, self.moe_top_k)) + if not np.array_equal(routing_data, expected_routing): + # Find all mismatched tokens + mismatch_mask = (routing_data != expected_routing).any(axis=(1, 2)) + mismatched_token_indices = np.where(mismatch_mask)[0] + # Check for duplicate slots in gather + unique_slots, counts = np.unique(slot_mapping, return_counts=True) + num_duplicates = np.sum(counts > 1) + dup_info = "" + if num_duplicates > 0: + dup_indices = np.where(counts > 1)[0] + dup_slots = unique_slots[dup_indices] + dup_info = f", duplicate_slots={list(dup_slots)}" + logger.error( + f"[R3 Debug] Gather mismatch! seq_len={seq_len}, mismatched_tokens={len(mismatched_token_indices)}, " + f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]{dup_info}" + ) + logger.error(f"Mismatched token indices: {mismatched_token_indices}") + for idx in mismatched_token_indices: # Print all mismatches tokens + logger.error( + f" position={idx}, slot={slot_mapping[idx]}, " + f"expected={expected_routing[idx, 0, 0]}, actual={routing_data[idx, 0, 0]}" + ) + raise ValueError("[R3 Debug]Routing gather validation failed.") + else: + logger.debug( + f"[R3 Debug] Gather validation passed: seq_len={seq_len}, " + f"slots=[{slot_mapping[0]}...{slot_mapping[-1]}]" + ) + + return routing_data + + def on_request_finished(self, request_id: str, block_table, seq_len: int) -> Optional[np.ndarray]: + """ + Unified entry point when a request finishes. Called by TokenProcessor on EOS detection. + Scheduler/TokenProcessor passes request_id, block_table, seq_len. + + Returns: + - "response" mode: routing_data numpy array (caller attaches to RequestOutput) + - "local"/"rdma" mode: None (submitted to StoreWrapper internally) + """ + routing_data = self.gather_routing_for_request(block_table, seq_len) + + if self._store_wrapper is not None: + # P2PStore mode: submit to store + rollout_id = split_request_id(request_id) + # Transpose to [num_moe_layers, seq_len, top_k] for store compatibility + # TODO(gongshaotian): Delete redundant transpose + routing_data = np.ascontiguousarray(routing_data.transpose(1, 0, 2)) + + if self.use_fused_put: + self._store_wrapper.submit_put_task(routing_indices=routing_data, rollout_id=rollout_id) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) + else: + for layer_id in range(self.num_moe_layers): + layer_buffer = routing_data[layer_id] + self._store_wrapper.submit_put_task( + routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id + ) + if self.only_last_turn: + self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id, layer_idx=layer_id) + return None + else: + # Response mode: return data for caller to attach to RequestOutput + return routing_data + + def reset(self): + """Reset SharedMemory buffer. Used during RL round cleanup.""" + self.host_buffer.buffer[:] = -1 + + def close(self): + """Clean up SharedMemory resources.""" + if self.host_view is not None: + self.host_view.close() + self.host_view = None + if self.host_buffer is not None: + self.host_buffer.close() + self.host_buffer = None diff --git a/fastdeploy/cache_manager/routing_store.py b/fastdeploy/cache_manager/routing_store.py new file mode 100644 index 00000000000..cd1dce19bd4 --- /dev/null +++ b/fastdeploy/cache_manager/routing_store.py @@ -0,0 +1,515 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import asyncio +import atexit +import functools +import multiprocessing +import os +import shutil +import threading +import time +import traceback +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Process, Queue +from typing import Optional, TypedDict + +import numpy as np +import paddle + +from fastdeploy.utils import get_logger + +logger = get_logger("routing_cache_manager", "routing_cache_manager.log") + +from fastdeploy.config import RoutingReplayConfig + + +class StoreTask(TypedDict): + task_type: str + key: str + data: np.ndarray + + +class StoreWrapper(object): + def __init__(self, fd_config) -> None: + super().__init__() + self.fd_config = fd_config + + # Initialize task queue + moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index + max_num_seqs = fd_config.scheduler_config.max_num_seqs + self.queue_max_size = moe_layer_num * max_num_seqs * 1000 + + self.manager = multiprocessing.Manager() + self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) + + self._monitor_thread: threading.Thread = None + self._stop_monitor = threading.Event() + + # Initialize consumer process + self._routing_store_process = StoreProcess( + task_queue=self._task_queue, + routing_replay_config=self.fd_config.routing_replay_config, + max_model_len=self.fd_config.model_config.max_model_len, + ) + self._store_process_running = False + + # Register atexit handler + atexit.register(self.shutdown) + + def shutdown(self): + """ """ + if not self._store_process_running: + return + self._store_process_running = False + + # Stop the monitor thread + self._stop_monitor.set() + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(timeout=3.0) + + # Put a sentinel value to signal the consumer to stop + if self._routing_store_process and self._routing_store_process.is_alive(): + try: + self._task_queue.put_nowait(None) + except Exception as e: + logger.info(f"Could not put sentinel into queue: {e}") + + if self._routing_store_process and self._routing_store_process.is_alive(): + # Wait for all tasks to be processed + self._routing_store_process.join(timeout=10.0) + if self._routing_store_process.is_alive(): + self._routing_store_process.close() + self._routing_store_process.join() + + self._task_queue.join() + self.manager.shutdown() + self._store_process_running = False + + def start_store_warpper(self): + """ """ + if self._store_process_running: + return + self._store_process_running = True + + # Start monitor thread + self._stop_monitor.clear() + self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) + self._monitor_thread.start() + + # Start Routing Store Wrapper in sub process + self._routing_store_process.start() + + def _monitor_queue_load(self): + """ """ + while not self._stop_monitor.is_set(): + time.sleep(2.0) + if not self._store_process_running: + break + qsize = self._task_queue.qsize() + + # Alarm when the task exceeds 80% of the queue capacity + if qsize > self.queue_max_size * 0.8: + logger.warning( + f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " + "Consider increasing max_workers or queue_max_size." + ) + logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") + + def submit_put_task(self, routing_indices: np.ndarray, rollout_id: str, layer_idx: int = None) -> None: + """Submit a put task to the task queue""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{rollout_id}_{layer_idx}" + else: + rdma_rollout_key = rollout_id + + task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices} + + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") + logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_store_task(self) -> None: + """Submit clear store task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + + start_time = time.perf_counter() + task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} + + try: + self._task_queue.put_nowait(task) + # Wait for the task to be processed + self._task_queue.join() + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") + + def submit_clear_prefix_batch_task(self, rollout_id, layer_idx: int = None) -> None: + """Submit clear prefix batch task""" + if not self._store_process_running: + raise RuntimeError("Store not started.") + prefix_batch_id = self.get_needed_clear_ids(rollout_id) + if prefix_batch_id is None: + return + start_time = time.perf_counter() + if layer_idx is not None: + rdma_rollout_key = f"{prefix_batch_id}_{layer_idx}" + else: + rdma_rollout_key = prefix_batch_id + + task: StoreTask = {"task_type": "clear_prefix_batch", "key": rdma_rollout_key, "data": None} + try: + self._task_queue.put_nowait(task) + except Exception: + raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + logger.info( + f"[R3] Submit clear prefix batch task for key: {prefix_batch_id}, cost time: {time.perf_counter()-start_time} s" + ) + + def get_needed_clear_ids(self, rollout_id: str) -> Optional[str]: + """ + Generate the prefix IDs for all closed multi-round tasks. + rollout_id: "xxx_xxx_epoch_15:2:2:1" + example: xxx_xxx_data_id:gen_id:turn_id:segment_id + """ + reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = rollout_id[::-1].split(":", 2) + prefix_gen_id = reversed_prefix_gen_id[::-1] + turn_id = eval(reversed_turn_id[::-1]) + segment_id = eval(reversed_segment_id[::-1]) + + assert turn_id >= 0 and segment_id >= 0 + prefix_batch = None + if turn_id > 0: + prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" + return prefix_batch + + +class StoreProcess(Process): + def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: + super().__init__() + self.max_model_len = max_model_len + self._task_queue = task_queue + self.routing_replay_config = routing_replay_config + self.max_workers = 5 + self._closed = False + + # Note: _routing_store and _event_loop_thread must be initialized in run() + # because they cannot be properly inherited after fork() + self._routing_store = None + self._event_loop_thread = None + + def run(self): + logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") + + # Initialize routing store in subprocess + self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) + + # Initialize event loop thread in subprocess + self._event_loop_thread = AsyncEventLoopThread() + self._event_loop_thread.start() + if not self._event_loop_thread._started_event.wait(timeout=5.0): + raise RuntimeError("Failed to start async event loop thread in subprocess") + + clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) + self._task_queue.put_nowait(clear_store_task) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + while not self._closed: + try: + task = self._task_queue.get() + if task is None: # Sentinel + self._task_queue.task_done() + break + + if task["task_type"] == "put": + future = executor.submit(self.process_put_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_store": + future = executor.submit(self.process_clear_store_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + elif task["task_type"] == "clear_prefix_batch": + future = executor.submit(self.process_clear_prefix_batch_task, task) + future.add_done_callback(lambda f: self._task_queue.task_done()) + except Exception as e: + self._task_queue.task_done() + raise RuntimeError(f"Error during processing task. {e}") + + logger.info("RoutingReplay Consumer Process Shutdown.") + + def process_put_task(self, store_task: StoreTask) -> None: + try: + # TODO(gongshaotian): delete this after trainer support dynamic len + store_task["data"] = self.pad_routing_indices(store_task["data"]) + coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting put task: {e}") + traceback.print_exc() + raise + + def process_clear_store_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_store() + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error during processing clear store task. {e}") + traceback.print_exc() + raise + + def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: + try: + coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) + future = self._event_loop_thread.submit_coroutine( + coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) + ) + return future + except Exception as e: + logger.error(f"Error submitting clear_prefix_batch task: {e}") + traceback.print_exc() + raise + + def _on_async_task_completed(self, task, future): + """ """ + try: + # result = future.result() + logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") + except Exception as e: + logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") + traceback.print_exc() + raise + + def close(self): + """Close the store process""" + self._closed = True + if hasattr(self, "_event_loop_thread"): + self._event_loop_thread.stop() + + def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: + """Pad routing indices of the request levevl to max model len""" + routing_shape = routing_indices.shape + if len(routing_shape) == 2: # [token, topk] + pad_array = np.full( + shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=0) + + elif len(routing_shape) == 3: # [layer, token, topk] + pad_array = np.full( + shape=[ + routing_indices.shape[0], + (self.max_model_len - routing_indices.shape[1]), + routing_indices.shape[2], + ], + fill_value=-1, + dtype=routing_indices.dtype, + ) + return np.concatenate([routing_indices, pad_array], axis=1) + else: + raise ValueError(f"Invalid routing indices shape: {routing_shape}") + + +class AsyncEventLoopThread(threading.Thread): + def __init__(self): + super().__init__(daemon=True) + self._loop = None + self._started_event = threading.Event() + self._closed = False + + def run(self): + """Run the async event loop""" + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + + # Set the event loop to be started + self._started_event.set() + logger.info("[EventLoopThread] Event loop started, running forever...") + + try: + self._loop.run_forever() + logger.info("[EventLoopThread] Event loop stopped") + except Exception as e: + logger.error(f"[EventLoopThread] Event loop exception: {e}") + traceback.print_exc() + finally: + logger.info("[EventLoopThread] Closing event loop") + self._loop.close() + + def submit_coroutine(self, coro, callback=None): + """Thread safely submit coroutine to event loop""" + if self._closed: + raise RuntimeError("Event loop thread is closed") + if not self._started_event.wait(timeout=5.0): + raise RuntimeError("Event loop failed to start within 5 seconds") + + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + + if callback: + + def wrapped_callback(f): + try: + callback(f) + except Exception as e: + logger.error(f"Error in callback: {e}") + traceback.print_exc() + + future.add_done_callback(wrapped_callback) + return future + + def stop(self): + """Stop the event loop""" + if not self._closed: + self._closed = True + if self._loop: + self._loop.call_soon_threadsafe(self._loop.stop) + + +class RoutingStoreBase(ABC): + """Base class for routing store""" + + def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: + self.routing_replay_config = routing_replay_config + + @abstractmethod + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + raise NotImplementedError + + @abstractmethod + async def clear_store( + self, + ): + """Clear the routing indices store""" + raise NotImplementedError + + @abstractmethod + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreLocal(RoutingStoreBase): + """Routing Store using local memory""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + self.local_store_dir = routing_replay_config.local_store_dir + os.makedirs(self.local_store_dir, exist_ok=True) + + async def put( + self, + routing_key: str, + routing_indices: np.ndarray, + ) -> None: + """Put the routing indices into store""" + # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor + time_before_put = time.perf_counter() + + if len(routing_indices.shape) == 2: + re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) + rollout_id = re_rollout_id[::-1] + layer_id = re_layer_id[::-1] + request_path = os.path.join(self.local_store_dir, rollout_id) + file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") + elif len(routing_indices.shape) == 3: + request_path = os.path.join(self.local_store_dir, routing_key) + file_path = os.path.join(request_path, f"{routing_key}.pdtensor") + else: + raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") + + paddle.save(routing_indices, file_path) + logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") + + async def clear_store(self): + """Clear the routing indices store""" + if os.path.isdir(self.local_store_dir): + shutil.rmtree(self.local_store_dir) + + logger.info("[R3] Clear routing store.") + + async def clear_prefix_batch(self, routing_prefix_key: str): + """Clear the routing indices""" + raise NotImplementedError + + +class RoutingStoreRDMA(RoutingStoreBase): + """Routing Store using RDMA""" + + def __init__(self, routing_replay_config) -> None: + super().__init__(routing_replay_config=routing_replay_config) + try: + # Only used in RLHF + from p2pstore import P2PClient, P2PConfig + except ModuleNotFoundError: + raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") + + rdma_store_server = routing_replay_config.rdma_store_server + p2pConfig = P2PConfig(metadata_server=rdma_store_server) + self.p2p_client = P2PClient(p2pConfig) + + async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: + """Put the routing indices into store""" + time_before_put = time.perf_counter() + if len(routing_indices.shape) == 3: + # NOTE(gongshaotian) Fused put with bytes data + routing_bytes = routing_indices.tobytes() + result = await self.p2p_client.put(routing_key, routing_bytes) + else: + result = await self.p2p_client.put(routing_key, routing_indices) + logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") + return result + + async def clear_prefix_batch(self, routing_prefix_key: str): + time_before_clear = time.perf_counter() + result = await self.p2p_client.delete_batch([routing_prefix_key]) + logger.info( + f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" + ) + return result + + async def clear_store(self): + """Clear the routing indices store""" + time_before_clear = time.perf_counter() + result = await self.p2p_client.clear() + logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") + return result + + +def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: + if routing_replay_config.routing_store_type == "local": + return RoutingStoreLocal(routing_replay_config=routing_replay_config) + elif routing_replay_config.routing_store_type == "rdma": + return RoutingStoreRDMA(routing_replay_config=routing_replay_config) + else: + raise ValueError( + f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " + "Valid types are: 'local', 'rdma'" + ) diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py index 466caef59e9..a938c043422 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/attention_store.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import os import time import traceback from dataclasses import dataclass @@ -25,12 +26,16 @@ KVCacheStorage, logger, ) +from fastdeploy.platforms import current_platform try: import attentionstore_sdk.api.common.common_pb2 as common_pb2 from attentionstore_sdk.sdk import AttentionStoreSDK, Tokens from attentionstore_sdk.utils.err import AttentionStoreSDKError + if current_platform.is_cuda(): + from attentionstore_sdk.client.client import AttentionType + _ATTENTIONSTORE_AVAILABLE = True except Exception: AttentionStoreSDK = None @@ -51,6 +56,7 @@ class AttentionStoreConfig: bytes_per_shard_layer_per_block: int = 1024 device_id: int = 0 dp_id: int = 0 + splitwise_role: str = "mixed" class AttentionStore(KVCacheStorage): @@ -62,19 +68,44 @@ def __init__(self, **args): self.config = AttentionStoreConfig(**args) try: + self.config.namespace = os.getenv("AS_NAMESPACE", self.config.namespace) + self.config.pod_name = os.getenv("AS_POD_NAME", self.config.pod_name) + if int(os.getenv("ENABLE_EP_DP_IN_FD", "1")): + self.config.pod_name = ( + self.config.pod_name + "_" + self.config.splitwise_role + "_" + str(self.config.dp_id) + ) + self.config.model_version = os.getenv("AS_MODEL_VERSION", self.config.model_version) logger.info(f"[INIT] Start initializing AttentionStoreSDK with config: {self.config}") - self.sdk = AttentionStoreSDK( - self.config.namespace, - self.config.pod_name, - self.config.model_version, - self.config.shard_id, - self.config.shard_num, - self.config.layer_num, - self.config.block_token_size, - self.config.bytes_per_shard_layer_per_block, - self.config.device_id, - self.config.dp_id, - ) + if current_platform.is_cuda(): + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + attention_type=AttentionType.MHA, + enable_as_kv_rw=True, + gpu_count=0, + ) + else: + self.sdk = AttentionStoreSDK( + self.config.namespace, + self.config.pod_name, + self.config.model_version, + self.config.shard_id, + self.config.shard_num, + self.config.layer_num, + self.config.block_token_size, + self.config.bytes_per_shard_layer_per_block, + self.config.device_id, + self.config.dp_id, + ) self.wait_for_sdk_ready(timeout=300, delta_t=5) logger.info("[INIT] ✅ AttentionStore is initialized successfully!") except Exception as e: @@ -120,15 +151,27 @@ def read( v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: - num = self.sdk.read( - list(range(self.config.layer_num)), - tokens, - start_read_block_idx, - k_data_ptrs, - v_data_ptrs, - gpu_block_ids, - timeout, - ) + if current_platform.is_cuda(): + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + remote_addrs=None, + ) + else: + num = self.sdk.read( + list(range(self.config.layer_num)), + tokens, + start_read_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) logger.debug(f"[READ END] task_id: {task_id} read_blocks: {num}") except AttentionStoreSDKError: logger.error( @@ -154,15 +197,28 @@ def write( v_data_ptrs = [v.data_ptr() for v in val_cache] num = 0 try: - num = self.sdk.write( - list(range(self.config.layer_num)), - tokens, - start_write_block_idx, - k_data_ptrs, - v_data_ptrs, - gpu_block_ids, - timeout, - ) + if current_platform.is_cuda(): + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + h2h_copy=False, + params=None, + ) + else: + num = self.sdk.write( + list(range(self.config.layer_num)), + tokens, + start_write_block_idx, + k_data_ptrs, + v_data_ptrs, + gpu_block_ids, + timeout, + ) logger.debug(f"[WRITE END] task_id: {task_id} written_blocks: {num}") except AttentionStoreSDKError: logger.error( diff --git a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py index 3fc10996a61..1a81cfd652f 100644 --- a/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py +++ b/fastdeploy/cache_manager/transfer_factory/mooncake_store/mooncake_store.py @@ -31,7 +31,14 @@ from fastdeploy.utils import get_host_ip DEFAULT_GLOBAL_SEGMENT_SIZE = 1024 * 1024 * 1024 # 1 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128MB +DEFAULT_LOCAL_BUFFER_SIZE = 1024 * 1024 # 1MB +DEFAULT_MC_MAX_MR_SIZE = 4 * 1024 * 1024 * 1024 # 4GB +MIN_MC_MAX_MR_SIZE = 1024 * 1024 * 1024 # 1GB +MAX_MC_MAX_MR_SIZE = 6 * 1024 * 1024 * 1024 # 6GB + + +def byte_to_gb(byte): + return byte / (1024 * 1024 * 1024) @dataclass @@ -111,9 +118,25 @@ def __init__(self, tp_rank=None): host_ip = get_host_ip() os.environ["MC_TCP_BIND_ADDRESS"] = host_ip logger.info(f"Set MC_TCP_BIND_ADDRESS to {host_ip}") - if os.environ.get("MC_MAX_MR_SIZE") is None: - os.environ["MC_MAX_MR_SIZE"] = "4294967296" # 4GB - logger.info("MC_MAX_MR_SIZE is not set, default to 4GB.") + + # Set MC_MAX_MR_SIZE for mooncake store to control the maximum mr size + self.mc_max_mr_size = int(os.environ.get("MC_MAX_MR_SIZE", 0)) + if self.mc_max_mr_size == 0: + self.mc_max_mr_size = DEFAULT_MC_MAX_MR_SIZE + logger.info(f"MC_MAX_MR_SIZE is not set, default to {byte_to_gb(DEFAULT_MC_MAX_MR_SIZE)} GB.") + elif self.mc_max_mr_size < MIN_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MIN_MC_MAX_MR_SIZE + logger.info( + f"MC_MAX_MR_SIZE is smaller than {byte_to_gb(MIN_MC_MAX_MR_SIZE)} GB, set to {byte_to_gb(MIN_MC_MAX_MR_SIZE)} GB." + ) + elif self.mc_max_mr_size > MAX_MC_MAX_MR_SIZE: + self.mc_max_mr_size = MAX_MC_MAX_MR_SIZE + logger.info( + f"MC_MAX_MR_SIZE is larger than {byte_to_gb(MAX_MC_MAX_MR_SIZE)} GB, set to {byte_to_gb(MAX_MC_MAX_MR_SIZE)} GB." + ) + else: + logger.info(f"MC_MAX_MR_SIZE is set to {self.mc_max_mr_size} bytes.") + os.environ["MC_MAX_MR_SIZE"] = str(self.mc_max_mr_size) try: from mooncake.store import MooncakeDistributedStore @@ -129,6 +152,11 @@ def __init__(self, tp_rank=None): self.config = MooncakeStoreConfig.create() if self.tp_rank is not None: self.config.select_rdma_device(self.tp_rank) + if self.config.local_buffer_size > self.mc_max_mr_size: + raise ValueError( + f"local_buffer_size {self.config.local_buffer_size} must be " + f"smaller than mc_max_mr_size {self.mc_max_mr_size}" + ) logger.info(f"Mooncake Configuration loaded, {self.config}.") ret_code = self.store.setup( @@ -162,13 +190,38 @@ def warmup(self): self.store.remove(warmup_key) def register_buffer(self, buffer_ptr, buffer_size) -> None: - try: - ret_code = self.store.register_buffer(buffer_ptr, buffer_size) - if ret_code: - logger.error(f"failed to register buffer, error code: {ret_code}") - except TypeError as err: - logger.error("Failed to register buffer to Mooncake Store: %s", err) - raise TypeError("Mooncake Store Register Buffer Error.") from err + """Register a buffer with Mooncake Store. + If buffer_size exceeds mc_max_mr_size, the buffer is split into + multiple chunks, each registered separately. + cuda_host_alloc returns physically contiguous pinned memory, so + pointer offset arithmetic is valid for sub-region registration. + """ + max_mr_size = self.mc_max_mr_size + if buffer_size <= max_mr_size: + try: + ret_code = self.store.register_buffer(buffer_ptr, buffer_size) + assert ret_code == 0, f"failed to register buffer, error code: {ret_code}" + except TypeError as err: + logger.error("Failed to register buffer to Mooncake Store: %s", err) + raise TypeError("Mooncake Store Register Buffer Error.") from err + else: + num_chunks = (buffer_size + max_mr_size - 1) // max_mr_size + logger.info( + f"Registering buffer of {byte_to_gb(buffer_size):.2f}GB in {num_chunks} chunks " + f"(max_mr_size={byte_to_gb(max_mr_size):.2f}GB per chunk)" + ) + for i in range(num_chunks): + chunk_ptr = buffer_ptr + i * max_mr_size + chunk_size = min(max_mr_size, buffer_size - i * max_mr_size) + try: + ret_code = self.store.register_buffer(chunk_ptr, chunk_size) + assert ret_code == 0, ( + f"failed to register chunk {i}/{num_chunks}, " + f"size={byte_to_gb(chunk_size):.2f}GB, error code: {ret_code}" + ) + except TypeError as err: + logger.error("Failed to register chunk %d/%d to Mooncake Store: %s", i, num_chunks, err) + raise TypeError("Mooncake Store Register Buffer Error.") from err def set( self, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index b15a6dc824b..1d32f386910 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -211,6 +211,7 @@ def __init__( self.enable_logprob = False self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" + self.enable_keep_sampling_mask = False self.redundant_experts_num = 0 self.seed = 0 self.quantization = None @@ -378,6 +379,9 @@ def override_name_from_config(self): # Because the ERNIE 4.5 config.json contains two sets of keys, adaptation is required. self.moe_num_shared_experts = self.n_shared_experts + if hasattr(self, "num_experts_per_tok") and not hasattr(self, "moe_k"): + self.moe_k = self.num_experts_per_tok + def read_from_env(self): """ Read configuration information from environment variables and update the object's attributes. @@ -671,6 +675,7 @@ def __init__( self.pod_ip: str = None # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). self.disable_custom_all_reduce: bool = False + self.enable_flashinfer_allreduce_fusion: bool = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) @@ -774,7 +779,7 @@ class SpeculativeConfig: "benchmark_mode": False, "enf_gen_phase_tag": False, "enable_draft_logprob": False, - "verify_strategy": "topp", + "verify_strategy": "target_match", "accept_policy": "normal", } @@ -1058,6 +1063,7 @@ def __init__( - None (default): capture sizes are inferred from llm config. - list[int]: capture sizes are specified as given.""" self.cudagraph_capture_sizes: Optional[list[int]] = None + self.flag_cudagraph_capture_sizes_initlized = False self.cudagraph_capture_sizes_prefill: list[int] = [1, 2, 4, 8] """ Number of warmup runs for cudagraph. """ self.cudagraph_num_of_warmups: int = 2 @@ -1108,13 +1114,27 @@ def __init__( self.check_legality_parameters() - def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0) -> None: + def init_with_cudagrpah_size( + self, + max_capture_size: int = 0, + max_capture_shape_prefill: int = 0, + num_speculative_tokens: int = 0, + ) -> None: """ Initialize cuda graph capture sizes and pre-compute the mapping from batch size to padded graph size """ # Regular capture sizes - self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] + if num_speculative_tokens != 0: + max_capture_size = max_capture_size * (num_speculative_tokens + 1) + if not self.flag_cudagraph_capture_sizes_initlized and num_speculative_tokens != 0: + self.cudagraph_capture_sizes = [ + size * (num_speculative_tokens + 1) + for size in self.cudagraph_capture_sizes + if (size * (num_speculative_tokens + 1)) <= max_capture_size + ] + else: + self.cudagraph_capture_sizes = [size for size in self.cudagraph_capture_sizes if size <= max_capture_size] self.cudagraph_capture_sizes_prefill = [ size for size in self.cudagraph_capture_sizes_prefill if size <= max_capture_shape_prefill ] @@ -1154,24 +1174,41 @@ def init_with_cudagrpah_size(self, max_capture_size: int = 0, max_capture_shape_ self.real_shape_to_captured_size_prefill[bs] = end self.real_shape_to_captured_size_prefill[self.max_capture_size_prefill] = self.max_capture_size_prefill + if num_speculative_tokens != 0: + real_bsz_to_captured_size = {} + for capture_size in self.cudagraph_capture_sizes: + dummy_batch_size = int(capture_size / (num_speculative_tokens + 1)) + real_bsz_to_captured_size[dummy_batch_size] = capture_size + + def expand_bsz_map(real_bsz_to_captured_size): + sorted_items = sorted(real_bsz_to_captured_size.items()) + result = {} + prev_bsz = 0 + for curr_bsz, cap in sorted_items: + for bsz in range(prev_bsz + 1, curr_bsz + 1): + result[bsz] = cap + prev_bsz = curr_bsz + return result + + self.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size) + + self.flag_cudagraph_capture_sizes_initlized = True + def _set_cudagraph_sizes( self, max_capture_size: int = 0, max_capture_shape_prefill: int = 0, - dec_token_per_query_per_step: int = 1, ): """ Calculate a series of candidate capture sizes, and then extract a portion of them as the capture list for the CUDA graph based on user input. """ - # Shape [1, 2, 4, 8, 16, ... 120, 128] * dec_token_per_query_per_step - draft_capture_sizes = [i * dec_token_per_query_per_step for i in [1, 2, 4]] + [ - 8 * i * dec_token_per_query_per_step for i in range(1, 17) - ] - # Shape [128, 144, ... 240, 256] * dec_token_per_query_per_step - draft_capture_sizes += [16 * i * dec_token_per_query_per_step for i in range(9, 17)] - # Shape [256, 288, ... 992, 1024] * dec_token_per_query_per_step - draft_capture_sizes += [32 * i * dec_token_per_query_per_step for i in range(9, 33)] + # Shape [1, 2, 4, 8, 16, ... 120, 128] + draft_capture_sizes = [i for i in [1, 2, 4]] + [8 * i for i in range(1, 17)] + # Shape [128, 144, ... 240, 256] + draft_capture_sizes += [16 * i for i in range(9, 17)] + # Shape [256, 288, ... 992, 1024] + draft_capture_sizes += [32 * i for i in range(9, 33)] draft_capture_sizes_prefill = draft_capture_sizes.copy() draft_capture_sizes.append(max_capture_size) @@ -1417,6 +1454,8 @@ def __init__( self.rsync_config: Optional[Dict[str, Any]] = None for key, value in args.items(): if hasattr(self, key): + if key == "rsync_config" and isinstance(value, str): + value = json.loads(value) setattr(self, key, value) def __str__(self) -> str: @@ -1713,6 +1752,9 @@ def __init__(self, args: dict): else: self.metrics_port = self.api_server_port + def __str__(self): + return json.dumps({key: value for key, value in self.__dict__.items()}) + class CommitConfig: """ @@ -1800,7 +1842,7 @@ def __init__(self, args) -> None: self.enable_routing_replay: bool = False - # Routing store type: local/rdma + # Routing return mode: "local" (file store) / "rdma" (P2PStore) / "response" (attach to RequestOutput) self.routing_store_type: str = "local" # Local routing store @@ -1815,17 +1857,110 @@ def __init__(self, args) -> None: # Fused routing of all layers self.use_fused_put: bool = False + # Debug mode: hack topk_ids to use position_ids for validation + self.debug_mode: bool = False + + # Auto-filled by FDConfig from ModelConfig (do not set manually) + self.routing_dtype: str = "" # "uint8" / "uint16" / "uint32" + self.num_moe_layers: int = 0 + self.moe_top_k: int = 0 + if args is not None: for key, value in args.items(): if hasattr(self, key) and value != "None": setattr(self, key, value) + def postprocess(self, model_config: "ModelConfig") -> None: + """Fill computed fields from ModelConfig. Must be called after model-specific + field unification (e.g. GLM's first_k_dense_replace → moe_layer_start_index).""" + if not self.enable_routing_replay: + return + self.num_moe_layers = model_config.num_hidden_layers - model_config.moe_layer_start_index + if model_config.architectures[0] == "Glm4MoeForCausalLM": + self.moe_top_k = model_config.num_experts_per_tok + else: + self.moe_top_k = model_config.moe_k + num_experts = model_config.moe_num_experts + model_config.moe_num_shared_experts + total_number = num_experts + 1 # +1 for reserved fill value + if total_number <= 255: + self.routing_dtype = "uint8" + elif total_number <= 65535: + self.routing_dtype = "uint16" + elif total_number <= 4294967295: + self.routing_dtype = "uint32" + else: + raise ValueError(f"num_experts {num_experts} exceeds uint32 range") + if self.debug_mode: + self.routing_dtype = "int64" + def to_json_string(self): """ Convert routing replay config to json string. """ return json.dumps({key: value for key, value in self.__dict__.items()}) + def __str__(self): + return self.to_json_string() + + +class BenchmarkMetricsConfig: + """Configuration for in-process benchmark metrics logger. + + Args (passed as JSON dict via --benchmark-metrics-config): + enable: Whether to enable the benchmark metrics logger. Default: False. + window_size: Number of recent requests to aggregate. 0 = all requests (cumulative). + window_mode: Window aggregation mode. Default: "sliding". + "sliding" = sliding window (keep last N records), + "tumbling" = tumbling window (clear and restart after every N records). + percentiles: Comma-separated percentile values to compute, e.g. "50,90,95,99". + metrics: Comma-separated metric names to report, or "all". + Available metrics (aligned with benchmark_serving.py --percentile-metrics): + ttft - Time to First Token (client arrival → first token) + s_ttft - Server TTFT (inference start → first token) + tpot - Time per Output Token (excluding first token) + s_itl - Infer Inter-token Latency + e2el - End-to-end Latency (client arrival → last token) + s_e2el - Server E2EL (inference start → last token) + s_decode - Decode speed (tokens/s, excluding first token) + input_len - Prefix cache hit token count ("Cached Tokens" in benchmark_serving) + s_input_len - Infer input length (total prompt tokens on inference side) + output_len - Output token length per request + """ + + _DEFAULTS = { + "enable": False, + "window_size": 0, + "window_mode": "sliding", + "percentiles": "50,90,95,99", + "metrics": "all", + } + + _ALL_METRICS = [ + "ttft", # Time to First Token + "s_ttft", # Server TTFT + "tpot", # Time per Output Token + "s_itl", # Infer Inter-token Latency + "e2el", # End-to-end Latency + "s_e2el", # Server E2EL + "s_decode", # Decode speed (tok/s) + "input_len", # Prefix cache hit tokens (= "Cached Tokens" in benchmark_serving) + "s_input_len", # Infer input length (total prompt tokens) + "output_len", # Output token length + ] + + def __init__(self, args: Optional[dict] = None): + for key, value in self._DEFAULTS.items(): + setattr(self, key, value) + if args: + for key, value in args.items(): + if key in self._DEFAULTS: + setattr(self, key, value) + self.percentile_values = [float(p.strip()) for p in self.percentiles.split(",") if p.strip()] + if self.metrics == "all": + self.selected_metrics = set(self._ALL_METRICS) + else: + self.selected_metrics = {m.strip() for m in self.metrics.split(",") if m.strip()} + class FDConfig: """ @@ -1861,6 +1996,7 @@ def __init__( tool_parser: str = None, test_mode=False, routing_replay_config: Optional[RoutingReplayConfig] = None, + benchmark_metrics_config=None, deploy_modality: DeployModality = DeployModality.MIXED, ): self.model_config: ModelConfig = model_config # type: ignore @@ -1878,68 +2014,39 @@ def __init__( self.structured_outputs_config: StructuredOutputsConfig = structured_outputs_config self.router_config: RouterConfig = router_config self.routing_replay_config = routing_replay_config + self.benchmark_metrics_config = benchmark_metrics_config self.deploy_modality: DeployModality = deploy_modality + # Initialize cuda graph capture list max_capture_shape = self.scheduler_config.max_num_seqs - if self.speculative_config is not None and self.speculative_config.method in [ - SpecMethod.MTP, - SpecMethod.SUFFIX, - ]: - max_capture_shape = self.scheduler_config.max_num_seqs * ( - self.speculative_config.num_speculative_tokens + 1 - ) - assert max_capture_shape % 2 == 0, "CUDAGraph only supports capturing even token nums in MTP scenarios." - self.graph_opt_config.real_bsz_to_captured_size = { - k: 0 for k in range(1, self.scheduler_config.max_num_seqs + 1) - } if self.graph_opt_config.cudagraph_only_prefill: max_capture_shape = 512 else: - max_capture_shape = ( - max_capture_shape if self.speculative_config is not None else min(512, max_capture_shape) - ) + max_capture_shape = min(512, max_capture_shape) max_capture_shape_prefill = graph_opt_config.max_capture_shape_prefill if self.graph_opt_config.cudagraph_capture_sizes is None: - dec_token_per_query_per_step = ( - self.speculative_config.num_speculative_tokens + 1 - if self.speculative_config is not None and self.speculative_config.method is not None - else 1 - ) self.graph_opt_config._set_cudagraph_sizes( max_capture_size=max_capture_shape, max_capture_shape_prefill=max_capture_shape_prefill, - dec_token_per_query_per_step=dec_token_per_query_per_step, ) - if self.speculative_config is not None and self.speculative_config.method is not None: - real_bsz_to_captured_size = {} - for capture_size in self.graph_opt_config.cudagraph_capture_sizes: - dummy_batch_size = int(capture_size / (self.speculative_config.num_speculative_tokens + 1)) - real_bsz_to_captured_size[dummy_batch_size] = capture_size - def expand_bsz_map(real_bsz_to_captured_size): - """ - Expand a sparse batch size mapping into a dense one. - - Args: - real_bsz_to_captured_size (dict): Sparse batch size to capture size mapping. - Returns: - dict: Dense batch size to capture size mapping. - """ - sorted_items = sorted(real_bsz_to_captured_size.items()) - result = {} - prev_bsz = 0 - for curr_bsz, cap in sorted_items: - for bsz in range(prev_bsz + 1, curr_bsz + 1): - result[bsz] = cap - prev_bsz = curr_bsz - return result - - self.graph_opt_config.real_bsz_to_captured_size = expand_bsz_map(real_bsz_to_captured_size) self.graph_opt_config.init_with_cudagrpah_size( max_capture_size=max_capture_shape, max_capture_shape_prefill=max_capture_shape_prefill, + num_speculative_tokens=( + self.speculative_config.num_speculative_tokens + if ( + self.speculative_config is not None + and self.speculative_config.method + in [ + SpecMethod.MTP, + SpecMethod.SUFFIX, + ] + ) + else 0 + ), ) self.tokenizer = tokenizer @@ -1980,6 +2087,7 @@ def expand_bsz_map(real_bsz_to_captured_size): int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 0 and self.model_config is not None and self.model_config.enable_mm + and self.deploy_modality != DeployModality.TEXT ): self.max_prefill_batch = 1 # TODO:当前V0多模prefill阶段只支持并行度为1,待优化 else: @@ -2007,18 +2115,32 @@ def expand_bsz_map(real_bsz_to_captured_size): and self.router_config and self.router_config.router ): - # For RL scenario: version.yaml will be required for models in future releases. + # For RL scenario, version.yaml is required for models # Temporarily enforce use router to be enabled. self.model_config.read_model_version() self.read_from_config() self.postprocess() - self.init_cache_info() + self.init_pd_info() if test_mode: return self.check() # self.print() # NOTE: it's better to explicitly call .print() when FDConfig is initialized + @property + def enable_mm_runtime(self) -> bool: + return ( + self.model_config is not None + and self.model_config.enable_mm + and self.deploy_modality != DeployModality.TEXT + ) + + @property + def enable_rope_3d_runtime(self) -> bool: + return self.enable_mm_runtime and ( + getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False) + ) + def _disable_sequence_parallel_moe_if_needed(self, mode_name): if self.parallel_config.use_sequence_parallel_moe and self.graph_opt_config.use_cudagraph: self.parallel_config.use_sequence_parallel_moe = False @@ -2036,6 +2158,9 @@ def postprocess(self): # The first moe layer id of GLM4.5 model self.model_config.moe_layer_start_index = self.model_config.first_k_dense_replace + if self.routing_replay_config is not None: + self.routing_replay_config.postprocess(self.model_config) + if self.parallel_config.tensor_parallel_size <= self.worker_num_per_node or self.node_rank == 0: self.is_master = True self.master_ip = "0.0.0.0" @@ -2047,7 +2172,10 @@ def postprocess(self): if self.scheduler_config.max_num_batched_tokens is None: if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): - self.scheduler_config.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM + if int(envs.FD_DISABLE_CHUNKED_PREFILL): + self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len + else: + self.scheduler_config.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM else: if self.cache_config.enable_chunked_prefill: self.scheduler_config.max_num_batched_tokens = 2048 @@ -2057,9 +2185,21 @@ def postprocess(self): if self.long_prefill_token_threshold == 0: self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04) + if ( + self.model_config is not None + and self.model_config.enable_mm + and self.deploy_modality == DeployModality.TEXT + ): + if getattr(self.model_config, "rope_3d", False) or getattr(self.model_config, "use_3d_rope", False): + logger.info( + "Deploy modality is text; forcing the multimodal-capable model onto the 2D RoPE runtime path." + ) + setattr(self.model_config, "rope_3d", False) + setattr(self.model_config, "use_3d_rope", False) + self.cache_config.max_block_num_per_seq = int(self.model_config.max_model_len // self.cache_config.block_size) self.cache_config.postprocess(self.get_max_chunk_tokens(), self.scheduler_config.max_num_seqs) - if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER: + if self.model_config is not None and self.enable_mm_runtime and not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.cache_config.enable_prefix_caching = False if ( self.structured_outputs_config is not None @@ -2085,7 +2225,7 @@ def postprocess(self): f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. [auto, xgrammar, guidance, off]" ) - if self.model_config.enable_mm: + if self.enable_mm_runtime: if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0: self.cache_config.max_encoder_cache = self.scheduler_config.max_num_batched_tokens elif self.cache_config.max_encoder_cache != 0: @@ -2127,6 +2267,25 @@ def postprocess(self): self.speculative_config.num_speculative_tokens = 1 self.speculative_config.num_model_steps = 1 + if self.speculative_config is not None and self.speculative_config.method is not None: + num_spec_tokens = self.speculative_config.num_speculative_tokens + # For speculative, enlarge the threshold to trigger block preallocation earlier, + # since each step consumes num_spec_tokens + 1 slots at once + old_prealloc_threshold = self.cache_config.prealloc_dec_block_slot_num_threshold + prealloc_dec_block_slot = self.cache_config.prealloc_dec_block_slot_num_threshold * (num_spec_tokens + 1) + max_prealloc_dec_block_slot = max( + 0, self.cache_config.block_size * self.cache_config.enc_dec_block_num - 1 + ) + self.cache_config.prealloc_dec_block_slot_num_threshold = min( + prealloc_dec_block_slot, max_prealloc_dec_block_slot + ) + logger.info( + f"prealloc_dec_block_slot_num_threshold updated: {old_prealloc_threshold} -> " + f"{self.cache_config.prealloc_dec_block_slot_num_threshold} " + f"(num_spec_tokens={num_spec_tokens}, block_size={self.cache_config.block_size}, " + f"enc_dec_block_num={self.cache_config.enc_dec_block_num})" + ) + if self.scheduler_config.splitwise_role == "mixed": self._disable_sequence_parallel_moe_if_needed("Mixed") self.model_config.moe_phase = MoEPhase(phase="prefill") @@ -2201,9 +2360,15 @@ def check(self): assert ( self.scheduler_config.max_num_seqs >= 1 ), f"max_num_seqs: {self.scheduler_config.max_num_seqs} should be larger than 1" - assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs, ( + tokens_per_seq = ( + (getattr(self.speculative_config, "num_speculative_tokens", 0) + 1) + if self.speculative_config is not None and self.speculative_config.method is not None + else 1 + ) + assert self.scheduler_config.max_num_batched_tokens >= self.scheduler_config.max_num_seqs * tokens_per_seq, ( f"max_num_batched_tokens: {self.scheduler_config.max_num_batched_tokens} " - f"should be larger than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs}" + f"should be larger than or equal to max_num_seqs: {self.scheduler_config.max_num_seqs} " + f"* tokens_per_seq: {tokens_per_seq}" ) assert ( self.scheduler_config.max_num_batched_tokens @@ -2298,6 +2463,33 @@ def check(self): " CUDA 12.x → pip install cuda-python==12.*\n" ) + if self.benchmark_metrics_config is not None: + cfg = self.benchmark_metrics_config + assert isinstance( + cfg.enable, bool + ), f"BenchmarkMetricsConfig: 'enable' must be a bool, got {type(cfg.enable).__name__}" + assert ( + isinstance(cfg.window_size, int) and cfg.window_size >= 0 + ), f"BenchmarkMetricsConfig: 'window_size' must be a non-negative integer, got {cfg.window_size!r}" + assert cfg.window_mode in ( + "sliding", + "tumbling", + ), f"BenchmarkMetricsConfig: 'window_mode' must be 'sliding' or 'tumbling', got {cfg.window_mode!r}" + assert ( + isinstance(cfg.percentiles, str) and cfg.percentiles.strip() + ), f"BenchmarkMetricsConfig: 'percentiles' must be a non-empty string, got {cfg.percentiles!r}" + for p in cfg.percentile_values: + assert 0 <= p <= 100, f"BenchmarkMetricsConfig: percentile value {p} out of range [0, 100]" + assert ( + isinstance(cfg.metrics, str) and cfg.metrics.strip() + ), f"BenchmarkMetricsConfig: 'metrics' must be a non-empty string, got {cfg.metrics!r}" + if cfg.metrics != "all": + invalid = cfg.selected_metrics - set(BenchmarkMetricsConfig._ALL_METRICS) + assert not invalid, ( + f"BenchmarkMetricsConfig: unknown metric(s): {invalid}. " + f"Valid metrics: {BenchmarkMetricsConfig._ALL_METRICS}" + ) + def print(self): """ print all config @@ -2320,18 +2512,17 @@ def print(self): logger.info("{:<20}:{:<6}{}".format(k, "", v)) logger.info("=============================================================") - def init_cache_info(self): + def init_pd_info(self): """ - initialize cache info + initialize info for pd deployment """ - # TODO: group the splitiwse params # There are two methods for splitwise deployment: # 1. v0 splitwise_scheduler or dp_scheduler - # 2. v1 local_scheduler + router + # 2. v1 local_scheduler + router (optional) self.splitwise_version = None if self.scheduler_config.name in ("splitwise", "dp"): self.splitwise_version = "v0" - elif self.scheduler_config.name == "local" and self.router_config and self.router_config.router: + elif self.scheduler_config.name == "local": self.splitwise_version = "v1" # the information for registering this server to router or splitwise_scheduler @@ -2376,7 +2567,7 @@ def reset_value(cls, value_name, key): ) reset_value(self.cache_config, "cache_dtype", "infer_model_dtype") - def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): + def get_max_chunk_tokens(self, mm_max_tokens_per_item=None) -> int: """ get max chunk tokens @@ -2389,10 +2580,16 @@ def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): if paddle.is_compiled_with_xpu(): num_tokens = self.scheduler_config.max_num_batched_tokens else: - num_tokens = self.scheduler_config.max_num_seqs + # In MTP scenario, each sequence generates (num_speculative_tokens + 1) tokens per step + mtp_steps = ( + (getattr(self.speculative_config, "num_speculative_tokens", 0) + 1) + if self.speculative_config is not None and self.speculative_config.method is not None + else 1 + ) + num_tokens = self.scheduler_config.max_num_seqs * mtp_steps else: num_tokens = self.scheduler_config.max_num_batched_tokens - if mm_max_tokens_per_item is not None and self.deploy_modality != DeployModality.TEXT: + if self.enable_mm_runtime and mm_max_tokens_per_item is not None: max_mm_tokens = max( mm_max_tokens_per_item.get("image", 0), mm_max_tokens_per_item.get("video", 0), diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ff0965c56bb..a96a6ace489 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -188,6 +188,10 @@ class EngineArgs: """ Configuration for speculative execution. """ + benchmark_metrics_config: Optional[Dict[str, Any]] = None + """ + Configuration for in-process benchmark metrics logger. + """ dynamic_load_weight: bool = False """ dynamic load weight @@ -264,7 +268,7 @@ class EngineArgs: """ Flag to enable prefix caching. """ - enable_output_caching: bool = True + enable_output_caching: bool = False """ Flag to enable kv cache for output tokens, only valid in V1 scheduler. """ @@ -274,6 +278,11 @@ class EngineArgs: Flag to disable the custom all-reduce kernel. """ + enable_flashinfer_allreduce_fusion: bool = False + """ + Flag to enable all reduce fusion kernel in flashinfer. + """ + use_internode_ll_two_stage: bool = False """ Flag to use the internode_ll_two_stage kernel. @@ -332,6 +341,11 @@ class EngineArgs: Chunk size of moe input. """ + enable_moe_scores_elementwise_fuse: bool = False + """ + Flag to enable fused elementwise in get_moe_scores. Default is False (disabled). + """ + cache_transfer_protocol: str = "ipc,rdma" """ Protocol to use for cache transfer. @@ -460,6 +474,14 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + enable_keep_sampling_mask: bool = False + """ + When enabled, the server returns a sparse index list for each generated token, indicating + which vocabulary positions were retained after top_p/top_k sampling, and streams it to + the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]], + where each inner list contains the retained vocabulary indices for a predicted token. + """ + max_logprobs: int = 20 """ Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the @@ -571,12 +593,7 @@ def __post_init__(self): and not current_platform.is_maca() ): self.enable_prefix_caching = False - if ( - not current_platform.is_cuda() - or (self.speculative_config is not None and self.enable_logprob) - or self.splitwise_role == "prefill" - or self.dynamic_load_weight - ): + if not current_platform.is_cuda() or self.splitwise_role == "prefill": self.enable_overlap_schedule = False if self.enable_logprob: if not current_platform.is_cuda() and not current_platform.is_xpu(): @@ -592,10 +609,15 @@ def __post_init__(self): raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1") if self.splitwise_role != "mixed": - if self.scheduler_name == "local" and self.router is None: + if self.scheduler_name == "splitwise": raise ValueError( - f"When using {self.splitwise_role} role and the {self.scheduler_name} " - f"scheduler, please provide --router argument." + "Setting scheduler_name as splitwise is not supported in pd deployment, " + "please use router as scheduler." + ) + if self.scheduler_name == "local" and self.router is None: + console_logger.warning( + f"Running {self.splitwise_role} role with {self.scheduler_name} " + f"scheduler without --router. Router registration and request routing will be disabled." ) if not ( @@ -619,6 +641,10 @@ def __post_init__(self): "kvcache_storage_backend is only supported when ENABLE_V1_KVCACHE_SCHEDULER=1" ) + if envs.FD_PD_TRANSFER_VIA_STORAGE: + if self.kvcache_storage_backend is None: + raise ValueError("Must set kvcache_storage_backend when FD_PD_TRANSFER_VIA_STORAGE=1") + valid_model_impls = ["auto", "fastdeploy", "paddleformers"] if self.model_impl not in valid_model_impls: raise NotImplementedError( @@ -829,6 +855,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.speculative_config, help="Configuration for speculative execution.", ) + model_group.add_argument( + "--benchmark-metrics-config", + type=json.loads, + default=EngineArgs.benchmark_metrics_config, + help="Configuration for in-process benchmark metrics logger. " + "Pass '{}' for defaults or a JSON with keys: " + "window_size (int, 0=all requests), " + "percentiles (str, e.g. '50,90,95,99'), " + "metrics (str, 'all' or comma-separated subset).", + ) model_group.add_argument( "--dynamic-load-weight", action="store_true", @@ -857,11 +893,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--quantization", type=parse_quantization, default=EngineArgs.quantization, - help="Quantization name for the model, currently support " - "'wint8', 'wint4'," - "default is None. The priority of this configuration " - "is lower than that of the config file. " - "More complex quantization methods need to be configured via the config file.", + help="Quantization config for the model. Can be a simple method name " + "(e.g. 'wint8', 'wint4') or a full JSON quantization_config string " + '(e.g. \'{"quantization": "mix_quant", "kv_cache_quant_type": "block_wise_fp8", ' + '"dense_quant_type": "block_wise_fp8", "moe_quant_type": "block_wise_fp8"}\'). ' + "When a JSON config is provided, it is processed the same way as " + "quantization_config in the model's config.json. " + "If both CLI and config.json specify quantization_config, " + "config.json takes higher priority. Default is None.", ) model_group.add_argument( "--graph-optimization-config", @@ -893,6 +932,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--enable-keep-sampling-mask", + action="store_true", + default=EngineArgs.enable_keep_sampling_mask, + help=( + "Enable output of sampling mask as a sparse index list over the vocabulary. " + "For non-MTP decoding, this is a list[int] per token step indicating which " + "vocabulary indices were kept after top_p/top_k sampling. " + "For MTP decoding, this is a list[list[int]] per token step, where each inner " + "list corresponds to one MTP group." + ), + ) model_group.add_argument( "--max-logprobs", type=int, @@ -977,6 +1028,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_custom_all_reduce, help="Flag to disable custom all-reduce.", ) + parallel_group.add_argument( + "--enable-flashinfer-allreduce-fusion", + action="store_true", + default=EngineArgs.enable_flashinfer_allreduce_fusion, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parallel_group.add_argument( "--use-internode-ll-two-stage", action="store_true", @@ -1356,7 +1413,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_overlap_schedule, help="Enable overlapping schedule.", ) - + scheduler_group.add_argument( + "--enable-moe-scores-elementwise-fuse", + action="store_true", + default=EngineArgs.enable_moe_scores_elementwise_fuse, + help="Enable fused elementwise in get_moe_scores for MoE routing.", + ) model_group.add_argument( "--deploy-modality", type=str, @@ -1387,6 +1449,14 @@ def create_speculative_config(self) -> SpeculativeConfig: return SpeculativeConfig(speculative_args) + def create_benchmark_metrics_config(self): + """Create BenchmarkMetricsConfig if --benchmark-metrics-config is provided.""" + if self.benchmark_metrics_config is None: + return None + from fastdeploy.config import BenchmarkMetricsConfig + + return BenchmarkMetricsConfig(self.benchmark_metrics_config) + def create_scheduler_config(self) -> SchedulerConfig: """ Create and return a SchedulerConfig object based on the current settings. @@ -1466,6 +1536,7 @@ def create_engine_config(self) -> FDConfig: self.tensor_parallel_size = model_cfg.tensor_parallel_size speculative_cfg = self.create_speculative_config() + benchmark_metrics_cfg = self.create_benchmark_metrics_config() if not self.enable_chunked_prefill: if (current_platform.is_cuda() or current_platform.is_maca()) and self.splitwise_role == "mixed": # default enable chunked prefill @@ -1477,7 +1548,11 @@ def create_engine_config(self) -> FDConfig: if self.max_num_batched_tokens is None: if int(envs.ENABLE_V1_KVCACHE_SCHEDULER): - if current_platform.is_maca() or current_platform.is_iluvatar(): + if ( + int(envs.FD_DISABLE_CHUNKED_PREFILL) + or current_platform.is_maca() + or current_platform.is_iluvatar() + ): self.max_num_batched_tokens = self.max_model_len else: self.max_num_batched_tokens = 8192 # if set to max_model_len, it's easy to be OOM @@ -1526,5 +1601,6 @@ def create_engine_config(self) -> FDConfig: plas_attention_config=plas_attention_config, early_stop_config=early_stop_cfg, routing_replay_config=routing_replay_config, + benchmark_metrics_config=benchmark_metrics_cfg, deploy_modality=DeployModality.from_str(self.deploy_modality), ) diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index 4afb3dc5c49..c06292ec981 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -294,6 +294,7 @@ def __init__(self, cfg, pid): cfg.limit_mm_per_prompt, cfg.mm_processor_kwargs, cfg.tool_parser, + enable_mm_runtime=cfg.enable_mm_runtime, ) # Create data processor self.data_processor = self.input_processor.create_processor() @@ -446,7 +447,7 @@ async def add_request( ) if envs.ZMQ_SEND_BATCH_DATA and self.connection_manager is not None: request["zmq_worker_pid"] = self.connection_manager.worker_pid - if self.cfg.model_config.enable_mm: + if self.cfg.enable_mm_runtime: self.request_client.send_pyobj(request) else: self.request_client.send_json(request) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 61a9914225b..429cce3b0b7 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -30,7 +30,6 @@ import time import traceback import weakref -from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -42,13 +41,12 @@ import fastdeploy.metrics.trace as tracing from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.config import FDConfig +from fastdeploy.engine.common_engine_prepare_mixin import EngineServicePrepareMixin from fastdeploy.engine.register_manager import RegisterManager from fastdeploy.engine.request import ( - CompletionOutput, ControlRequest, ControlResponse, Request, - RequestMetrics, RequestOutput, RequestStatus, RequestType, @@ -115,7 +113,7 @@ def _format_worker_launch_failure_message(log_dir: str) -> str: return message -class EngineService: +class EngineService(EngineServicePrepareMixin): """ Base class containing common engine functionality """ @@ -139,6 +137,7 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.is_paused = False # pause request generation self._pause_cond = threading.Condition() + self._rejecting_new_requests = False # blocks new requests during abort drain self._ctrl_output_queues = {} self._ctrl_response_mailboxes = collections.defaultdict(collections.OrderedDict) @@ -197,10 +196,23 @@ def __init__(self, cfg: FDConfig, start_queue=True, use_async_llm=False): self.scheduler_metrics_logger = SchedulerMetricsLogger( enabled=True, dp_rank=self.cfg.parallel_config.local_data_parallel_id, + splitwise_role=self.cfg.scheduler_config.splitwise_role, ) self.resource_manager.scheduler_metrics_logger = self.scheduler_metrics_logger self.token_processor.set_scheduler_metrics_logger(self.scheduler_metrics_logger) + if self.cfg.benchmark_metrics_config is not None and self.cfg.benchmark_metrics_config.enable: + from fastdeploy.metrics.benchmark_metrics_logger import ( + BenchmarkMetricsLogger, + ) + + self.benchmark_metrics_logger = BenchmarkMetricsLogger( + config=self.cfg.benchmark_metrics_config, + log_dir=envs.FD_LOG_DIR, + dp_rank=self.cfg.parallel_config.local_data_parallel_id, + ) + self.token_processor.set_benchmark_logger(self.benchmark_metrics_logger) + self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1) for idx in range(1, self.cfg.max_num_partial_prefills + 1): self.partial_chunked_tokens[idx] = ( @@ -251,12 +263,13 @@ def start(self, async_llm_pid=None): self.start_worker_service(async_llm_pid) if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.insert_task_to_worker_thread = threading.Thread( - target=self._schedule_request_to_worker_v1, daemon=True - ) + self.prepare_request_thread = threading.Thread(target=self._prepare_request_v1, daemon=True) + self.prepare_request_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker_v1, daemon=True) + self.schedule_request_thread.start() else: - self.insert_task_to_worker_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) - self.insert_task_to_worker_thread.start() + self.schedule_request_thread = threading.Thread(target=self._schedule_request_to_worker, daemon=True) + self.schedule_request_thread.start() self.token_processor.tasks_queue = self.engine_worker_queue self.token_processor.run() if self.cfg.scheduler_config.splitwise_role == "decode": @@ -330,6 +343,7 @@ def create_data_processor(self): self.cfg.limit_mm_per_prompt, self.cfg.mm_processor_kwargs, self.cfg.tool_parser, + enable_mm_runtime=self.cfg.enable_mm_runtime, ) self.data_processor = self.input_processor.create_processor() self.mm_max_tokens_per_item = self.data_processor.get_mm_max_tokens_per_item( @@ -371,15 +385,6 @@ def _init_worker_monitor_signals(self): # exist_task_signal 用于各worker进 create=True, ) - engine_forward_signal_data = np.zeros([1], dtype=np.int32) - self.engine_forward_signal = IPCSignal( - name="engine_forward_signal", - array=engine_forward_signal_data, - dtype=np.int32, - suffix=current_suffix, - create=True, - ) - # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = np.zeros( shape=[min(self.cfg.worker_num_per_node, self.cfg.parallel_config.tensor_parallel_size)], dtype=np.int32 @@ -461,15 +466,19 @@ def start_worker_queue_service(self, start_queue): start queue service for engine worker communication """ if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: - address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port) + engine_worker_queue_address = (self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port) + engine_cache_queue_address = (self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port) else: - address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock" + engine_worker_queue_address = ( + f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.local_engine_worker_queue_port}.sock" + ) + engine_cache_queue_address = f"/dev/shm/fd_task_queue_{self.cfg.cache_config.local_cache_queue_port}.sock" if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0": if start_queue: - self.llm_logger.info(f"Starting engine worker queue server service at {address}") + self.llm_logger.info(f"Starting engine worker queue server service at {engine_worker_queue_address}") self.engine_worker_queue_server = EngineWorkerQueue( - address=address, + address=engine_worker_queue_address, is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, @@ -479,7 +488,7 @@ def start_worker_queue_service(self, start_queue): self.cfg.parallel_config.local_engine_worker_queue_port = ( self.engine_worker_queue_server.get_server_port() ) - address = ( + engine_worker_queue_address = ( self.cfg.master_ip, self.cfg.parallel_config.local_engine_worker_queue_port, ) @@ -489,17 +498,18 @@ def start_worker_queue_service(self, start_queue): f"Starting engine cache queue server service at {self.cfg.cache_config.local_cache_queue_port}" ) self.cache_task_queue = EngineCacheQueue( - address=(self.cfg.master_ip, self.cfg.cache_config.local_cache_queue_port), + address=engine_cache_queue_address, authkey=b"cache_queue_service", is_server=True, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=-1, local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) - self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port() + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + self.cfg.cache_config.local_cache_queue_port = self.cache_task_queue.get_server_port() self.engine_worker_queue = EngineWorkerQueue( - address=address, + address=engine_worker_queue_address, is_server=False, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, @@ -601,7 +611,7 @@ def insert_tasks(self, tasks: List[Request], current_id=-1): LoggingEventName.RESCHEDULED_INFERENCE_START, task.request_id, getattr(task, "user", "") ) if not is_prefill: - if not self.cfg.model_config.enable_mm: + if not self.cfg.enable_mm_runtime: self.update_requests_chunk_size(tasks) else: self.update_mm_requests_chunk_size(tasks) @@ -882,199 +892,19 @@ def _schedule_request_to_worker_v1(self): Insert tasks to worker with scheduler v1 (ENABLE_V1_KVCACHE_SCHEDULER=1). """ tracing.trace_set_thread_info("Scheduler Task to Work") - get_request_pool = ThreadPoolExecutor(max_workers=1) - is_fetching = False - - def _fetch_request(): - try: - with self._pause_cond: - self._pause_cond.wait_for(lambda: not self.is_paused) - nonlocal is_fetching - num_prefill_batch = min( - int(self.resource_manager.available_batch()), - self.cfg.max_prefill_batch, - ) - - if self.cfg.scheduler_config.splitwise_role != "mixed": - max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens - else: - max_num_batched_tokens = self.cfg.model_config.max_model_len - - available_blocks = self.cfg.cache_config.max_block_num_per_seq - tasks = self.scheduler.get_requests( - available_blocks=available_blocks, - block_size=self.cfg.cache_config.block_size, - reserved_output_blocks=0, # self.cfg.cache_config.enc_dec_block_num - max_num_batched_tokens=max_num_batched_tokens, - batch=num_prefill_batch, - ) - for task in tasks: - task.metrics.engine_get_req_time = time.time() - trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) - - if self.cfg.scheduler_config.splitwise_role == "decode": - # TODO: refine scheduler to remove this limitation - # Decode will process and schedule the request sent by prefill to engine, - # so the same request sent by the decode api server will be ignored - is_fetching = False - return - - if tasks: - self.llm_logger.debug( - f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" - ) - - if self.cfg.scheduler_config.splitwise_role == "prefill": - for task in tasks: - # start async preprocess - self.resource_manager.apply_async_preprocess(task) - need_delete_tasks = [] - if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for request: {task.request_id}" - ) - task.metrics.ask_decode_resource_start_time = time.time() - while True: - self.split_connector.send_splitwise_tasks([task], task.idx) - status, msg = self.split_connector.check_decode_allocated(task) - if not status: - self.llm_logger.warning( - f"D failed to allocate resource for request {task.request_id}, try again." - ) - time.sleep(0.05) - else: - task.metrics.ask_decode_resource_finish_time = time.time() - break - self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") - else: - for task in tasks: - # assure can allocate block ids in P - while not self.resource_manager.preallocate_resource_in_p(task): - time.sleep(0.005) - - self.llm_logger.debug( - f"P has allocated resources and then ask D resource for req_id: {task.request_id}" - ) - task.metrics.ask_decode_resource_start_time = time.time() - self.split_connector.send_splitwise_tasks([task], task.idx) - - for task in tasks: - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - task.metrics.ask_decode_resource_finish_time = time.time() - if not status: - self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - error_msg=msg, - ) - ] - ) - need_delete_tasks.append(task) - continue - for tmp_task in need_delete_tasks: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # to send cache info to cache messager - if tasks: - need_check_req_ids = [task.request_id for task in tasks] - self.split_connector.send_cache_info_to_messager(tasks, 0) - # ensure cache tasks has sent to cache_messager - need_check_req_ids = [task.request_id for task in tasks] - finished_ids, delete_tasks_list = [], [] - while need_check_req_ids: - finished_ids.extend(self.engine_worker_queue.get_finished_add_cache_task_req()) - self.llm_logger.debug( - f"P has successfully sent cache infos to cache messager for requests: {finished_ids}" - ) - if finished_ids: - for task in tasks: - result = self.resource_manager.waiting_async_process(task) - if result is None: - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=task.error_code, - error_msg=task.error_message, - ) - ] - ) - need_check_req_ids.remove(task.request_id) - delete_tasks_list.append(task) - elif result is False: - if task.request_id in finished_ids: - need_check_req_ids.remove(task.request_id) - finished_ids.remove(task.request_id) - else: - time.sleep(0.001) - - for tmp_task in delete_tasks_list: - tasks.remove(tmp_task) - # release resource in P - self.resource_manager.pre_recycle_resource(tmp_task.request_id) - - # Fetch requests and add them to the scheduling queue - if tasks: - for task in tasks: - task.metrics.add_req_to_resource_manager_time = time.time() - trace_print( - LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "") - ) - if self.cfg.scheduler_config.splitwise_role == "prefill": - self.resource_manager.add_request_in_p(tasks) - self.llm_logger.info( - f"P add requests into running queue: {[task.request_id for task in tasks]}" - ) - else: - for task in tasks: - self.resource_manager.add_request(task) - is_fetching = False - except Exception as e: - self.llm_logger.error(f"fetching request error {e} {str(traceback.format_exc())}") - is_fetching = False while self.running: with self._pause_cond: self._pause_cond.wait_for(lambda: not self.is_paused) + try: - if not is_fetching: - # Check if the thread pool is still available to avoid submitting tasks to a shutdown thread pool. - try: - is_fetching = True - get_request_pool.submit(_fetch_request) - except RuntimeError as e: - if "shutdown" in str(e): - self.llm_logger.info("Thread pool shutdown detected, exiting scheduler loop") - break - else: - raise - if self.cfg.scheduler_config.splitwise_role != "mixed": - # Continue preprocessing incoming requests and accumulating them in the queue when forward pass not finished. - # Once the forward pass finishes, these accumulated requests can be scheduled in larger, - # more efficient batches. - if self.engine_worker_queue.exist_tasks() or self.engine_forward_signal.value[0] != 0: - time.sleep(0.001) - continue - else: - # In mixed, todo: optimze cache swap, to decouple swap from scheduler - if self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - continue + if self.engine_worker_queue.exist_tasks(): + time.sleep(0.001) + continue if hasattr(self.resource_manager, "scheduler_unhandled_request_num"): self.resource_manager.scheduler_unhandled_request_num = self._get_scheduler_unhandled_request_num() + # 2. Schedule requests tasks, error_tasks = self.resource_manager.schedule() @@ -1083,14 +913,18 @@ def _fetch_request(): if self.cfg.scheduler_config.splitwise_role == "decode": for task in tasks: if task.task_type == RequestType.PREEMPTED: - msg = f"{task.request_id} decode not enough blocks, need to be rescheduled." + msg = ( + f"PD Error: decode does not have enough blocks for " + f"preallocated request. req:{task.request_id} " + ) self.llm_logger.error(msg) + main_process_metrics.reschedule_req_num.inc() self.scheduler.put_results( [ RequestOutput( request_id=task.request_id, finished=True, - error_code=500, + error_code=502, error_msg=msg, ) ] @@ -1132,13 +966,6 @@ def _fetch_request(): elif not task.has_been_preempted_before: task.metrics.inference_start_time = time.time() self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz)) - else: - # When there are no actual tasks to schedule, send an empty task batch to EP workers. - # This helps EP workers barrier for syncing tasks not hang. - if self.cfg.parallel_config.enable_expert_parallel: - self.engine_worker_queue.put_tasks( - ([], self.resource_manager.real_bsz) - ) # Empty (as idle tasks for ep) # 4. Response error tasks if error_tasks: @@ -1152,11 +979,14 @@ def _fetch_request(): time.sleep(0.005) except RuntimeError as e: - if "cannot schedule new futures after shutdown" in str(e): - break + raise e except Exception as e: err_msg = "Error happened while insert task to engine: {}, {}.".format(e, str(traceback.format_exc())) self.llm_logger.error(err_msg) + # Failed to connect to engine worker queue, retry after 5 seconds + if self.engine_worker_queue.is_broken(): + self.llm_logger.error("Failed to connect to engine worker queue, retry after 5 seconds") + time.sleep(5) def _get_scheduler_unhandled_request_num(self) -> int: """ @@ -1218,7 +1048,7 @@ def _insert_zmq_task_to_scheduler(self): while self.running: try: block = True if len(added_requests) == 0 else False - if not self.cfg.model_config.enable_mm: + if not self.cfg.enable_mm_runtime: err, data = self.recv_request_server.receive_json_once(block) else: err, data = self.recv_request_server.receive_pyobj_once(block) @@ -1268,14 +1098,32 @@ def _insert_zmq_task_to_scheduler(self): self.request_worker_map[req_id_for_map] = worker_pid status_value = data.get("status", None) if status_value is not None and status_value == RequestStatus.ABORT.value: - req_id = data["request_id"] - self.llm_logger.info(f"Receive abort request, req_id: {req_id}") - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager.add_abort_req_ids(req_id) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.llm_logger.info("abort requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") + else: + abort_all = data.get("abort_all", False) + req_ids = data.get("req_ids", []) + if abort_all or req_ids: + target_req_ids = self._resolve_abort_targets(abort_all, req_ids) + self.llm_logger.info( + f"Receive abort_reqs, abort_all={abort_all}, " + f"input={len(req_ids)}, resolved={len(target_req_ids)}" + ) + self.resource_manager.add_abort_req_ids(target_req_ids) + else: + req_id = data.get("request_id", None) + if not req_id: + self.llm_logger.warning( + "Receive abort request without request_id, skip invalid abort message" + ) + continue + self.llm_logger.info(f"Receive abort request, req_id: {req_id}") + self.resource_manager.add_abort_req_ids(req_id) continue err_msg = None try: request = Request.from_dict(data) + request.metrics.scheduler_recv_req_time = time.time() main_process_metrics.requests_number.inc() trace_carrier = data.get("trace_carrier") @@ -1287,13 +1135,29 @@ def _insert_zmq_task_to_scheduler(self): trace_print(LoggingEventName.REQUEST_QUEUE_START, data["request_id"], data.get("user", "")) self.llm_logger.debug(f"Receive request from api server: {request}") - if self.is_paused: + if self.is_paused or self._rejecting_new_requests: self.llm_logger.warning(f"Engine is paused, drop request: {request}") self._send_error_response( request.request_id, "Request is aborted since LLM Engine is paused.", worker_pid=worker_pid, ) + # PD ghost prevention: notify decode side to recycle its + # scheduler entry, otherwise it would sit there as a ghost + # since prefill will never deliver any first token. + if ( + self.cfg.scheduler_config.splitwise_role == "prefill" + and getattr(request, "disaggregate_info", None) + and self.split_connector is not None + ): + try: + self.split_connector.send_drop_signal( + request.request_id, request.disaggregate_info + ) + except Exception as e: + self.llm_logger.warning( + f"Failed to send drop signal for {request.request_id}: {e}" + ) continue except Exception as e: self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}") @@ -1407,38 +1271,19 @@ def _control_pause(self, control_request: ControlRequest): if self.is_paused: self.llm_logger.info("Engine is already paused, no need to pause again.") return - self.is_paused = True + self._rejecting_new_requests = True + self.resource_manager.log_status() - self.llm_logger.info("Abort running requests.") + # Scheduling loop picks them up via _trigger_abort when they enter resource_manager + all_req_ids = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) + self.llm_logger.info(f"Pause: aborting {len(all_req_ids)} total requests.") + if all_req_ids: + self.resource_manager.add_abort_req_ids(all_req_ids) + self._wait_inflight_drained() + with self._pause_cond: + self.is_paused = True self.resource_manager.log_status() - # preempted all running reqs. preempted reqs will be append to ResourceManager.waiting queue - timeout, count = 60, 0 - while self.engine_worker_queue.exist_tasks(): - time.sleep(0.001) - count += 1 - if count >= timeout * 1000: - break - if count >= timeout * 1000: - error_msg = f"Emptying engine worker queue timed out after {timeout} seconds, worker may hanged!" - self.llm_logger.error(error_msg) - raise Exception(error_msg) - running_reqs = self.resource_manager.preempted_all() - if len(running_reqs) > 0: - self.llm_logger.info(f"Total {len(running_reqs)} requests need to be aborted.") - self.resource_manager.get_real_bsz() - self.engine_worker_queue.put_tasks((running_reqs, self.resource_manager.real_bsz)) - self.resource_manager.wait_worker_inflight_requests_finish(timeout=60) - # self.engine_worker_queue.clear_data() - self.token_processor.clear_data() - self.resource_manager.log_status() - - # abort inflight requests to user - inflight_requests = self.scheduler.get_inflight_requests() - self.llm_logger.info(f"Abort inflight requests (total {len(inflight_requests)}).") - for req in inflight_requests: - self._send_error_response(req.request_id, "Request is aborted since engine is paused.") - self.scheduler.reset() # pause cache transfer if self.cfg.cache_config.num_cpu_blocks > 0 or self.cfg.cache_config.kvcache_storage_backend: @@ -1459,6 +1304,66 @@ def _control_pause(self, control_request: ControlRequest): self.llm_logger.info("Successfully paused request generation.") return None + def _wait_inflight_drained(self): + """ + Wait until resource_manager.requests is completely empty. + Logs a warning and remove scheduler-only request every 30 seconds while waiting to help diagnose potential hangs. + """ + start_time = time.monotonic() + next_warn_time = start_time + 30 + GHOST_REAP_AFTER = 30.0 + + while self.resource_manager.requests or self.scheduler.requests: + now = time.monotonic() + + late_ids = list( + set(self.resource_manager.requests.keys()) + - self.resource_manager.waiting_abort_req_id_set + - self.resource_manager.to_be_aborted_req_id_set + ) + if late_ids: + self.resource_manager.add_abort_req_ids(late_ids) + self.llm_logger.info(f"Pause drain: late-arrived requests added to abort set: {late_ids}") + + if now - start_time >= GHOST_REAP_AFTER: + scheduler_only_ids = list( + set(self.scheduler.requests.keys()) - set(self.resource_manager.requests.keys()) + ) + if scheduler_only_ids: + ghost_outputs = [ + RequestOutput( + request_id=req_id, + finished=True, + error_code=499, + error_msg=(f"forced cleanup after {GHOST_REAP_AFTER}s"), + ) + for req_id in scheduler_only_ids + ] + self.scheduler.put_results(ghost_outputs) + self.llm_logger.warning( + f"Pause drain timeout: reaped {len(scheduler_only_ids)} " + f"scheduler-only ghost(s) after {GHOST_REAP_AFTER}s: " + f"{scheduler_only_ids}" + ) + # Reset to avoid re-reaping on the next tick + start_time = now + + if now >= next_warn_time: + self.llm_logger.warning( + "Still waiting for inflight requests to drain, " + f"elapsed: {now - start_time:.3f} seconds, " + f"resource_manager.requests: {len(self.resource_manager.requests)}, " + f"scheduler.requests: {len(self.scheduler.requests)}", + ) + next_warn_time = now + 30 + + time.sleep(0.005) + + self.llm_logger.info( + "All inflight requests drained, take time: %.3f seconds", + time.monotonic() - start_time, + ) + def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: """Control function for resuming request generation. @@ -1474,6 +1379,7 @@ def _control_resume(self, control_request: ControlRequest) -> Optional[dict]: self.llm_logger.info("Engine is not paused, no need to resume.") return None self.is_paused = False + self._rejecting_new_requests = False self._pause_cond.notify_all() # resume cache transfer @@ -1556,139 +1462,6 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d return responses - def _control_abort_requests(self, control_req: ControlRequest): - if not envs.ENABLE_V1_KVCACHE_SCHEDULER: - raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER") - args = control_req.get_args() - abort_all = args.get("abort_all", False) - req_ids = args.get("req_ids", []) - matched_input_ids = set() - now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())) - - # Step 1: Determine target request list - if abort_all: - # all requests in running + waiting - target_req_ids = now_reqs - else: - # filter out requests that actually exist - target_req_ids = [] - for rid in req_ids: - if rid in now_reqs: - target_req_ids.append(rid) - matched_input_ids.add(rid) - elif f"{rid}_0" in now_reqs: - target_req_ids.append(f"{rid}_0") - matched_input_ids.add(rid) - - if not target_req_ids: - return {"aborted": [], "not_found": req_ids if not abort_all else []} - - # Step 2: Collect partial results - aborted_info = [] - results = [] - for req_id in target_req_ids: - request = self.resource_manager.requests.get(req_id) - if request is None: - scheduled_req = self.scheduler.requests.get(req_id) - if scheduled_req is None: - continue - request = scheduled_req.raw - - partial_token_ids = list(request.output_token_ids) - - # Construct finished response with partial results - now = time.time() - abort_metrics = RequestMetrics( - arrival_time=request.metrics.arrival_time if request.metrics else now, - inference_start_time=request.metrics.inference_start_time if request.metrics else now, - engine_recv_latest_token_time=now, - engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, - request_start_time=request.metrics.arrival_time if request.metrics else now, - ) - result = RequestOutput( - request_id=req_id, - finished=True, - outputs=CompletionOutput( - index=0, - send_idx=len(partial_token_ids), - token_ids=[self.data_processor.eos_token_ids[0]], - ), - metrics=abort_metrics, - error_code=200, - error_msg="Aborted", - ) - results.append(result) - aborted_info.append( - { - "request_id": req_id, - "output_token_count": len(partial_token_ids), - } - ) - - # Step 3: Execute abort — add all requests to waiting_abort_req_id_set - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - for req_id in target_req_ids: - self.resource_manager.add_abort_req_ids(req_id) - time.sleep(0.0001) - if self.cfg.scheduler_config.splitwise_role != "prefill": - self._wait_abort_complete(target_req_ids) - - # Add results to scheduler, engine will have a thread calling get_results, - # then cleanup and call send_response to send to client. - # When client disconnects, send_response will automatically ignore - if self.cfg.scheduler_config.splitwise_role != "prefill": - try: - # self.send_response_server.send_response(req_id, [result]) - self.scheduler.put_results(results) - except Exception: - pass # client may have disconnected - - not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] - - return {"aborted": aborted_info, "not_found": not_found} - - def _wait_abort_complete(self, target_req_ids, stall_timeout=1): - """ - Wait for all abort requests to complete. - - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet - - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, - reset progress state if any, then continue monitoring - """ - target_set = set(target_req_ids) - prev_remaining_count = len(target_set) - last_progress_time = time.time() - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - while remaining: - remaining = target_set & self.resource_manager.get_reqs_in_aborting() - if not remaining: - self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") - return - - current_count = len(remaining) - if current_count < prev_remaining_count: - # progress made: recycle_abort_task was called - self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") - last_progress_time = time.time() - prev_remaining_count = current_count - - if time.time() - last_progress_time > stall_timeout: - # no progress timeout: only cleanup requests stuck in to_be_aborted (worker hasn't returned -9) - stuck = remaining & self.resource_manager.to_be_aborted_req_id_set - if stuck: - self.llm_logger.warning( - f"no abort progress for {stall_timeout}s, " - f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" - ) - for req_id in list(stuck): - self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") - self.resource_manager.recycle_abort_task(req_id) - # reset progress state - last_progress_time = time.time() - prev_remaining_count = current_count - len(stuck) - # else: remaining are all in waiting_abort_req_id_set, waiting for natural flow - - time.sleep(0.005) - def _parse_tags(self, control_request: ControlRequest): """ Parse tags from control request. @@ -2039,6 +1812,31 @@ def _fetch_requests(): items = self.engine_worker_queue.get_disaggregated_tasks() for item in items: + msg_type = item[0] + + # PD pause race: P drops a request via paused gate and notifies us + # to recycle our scheduler entry (otherwise it becomes a ghost that + # blocks pause/abort drain forever). Synthesize a finished + # RequestOutput so it walks the normal put_results -> _recycle path + # and the client gets a 499 error response. + if msg_type == "decode_drop": + drop_outputs = [ + RequestOutput( + request_id=req_id, + finished=True, + error_code=499, + error_msg="Aborted: prefill dropped this request (paused gate)", + ) + for req_id in item[1] + ] + if drop_outputs: + self.scheduler.put_results(drop_outputs) + self.llm_logger.info( + "Decode recycled scheduler ghost(s) via P-side drop signal: " + f"{[r.request_id for r in drop_outputs]}" + ) + continue + tasks = item[1] if isinstance(tasks[0], Request): self.llm_logger.debug( @@ -2062,6 +1860,7 @@ def _process_allocate_resource_requests(): processed_indices = [] for idx, task in enumerate(allocate_resource_requests): is_success = False + trace_print(LoggingEventName.DECODE_PROCESS_PREALLOCATE_REQUEST_START, task.request_id, task.user) if envs.ENABLE_V1_KVCACHE_SCHEDULER: if self.resource_manager.preallocate_resource_in_d(task): @@ -2071,6 +1870,7 @@ def _process_allocate_resource_requests(): self.llm_logger.debug(f"D has successfully sent cache infos for task {task.request_id}") processed_indices.append(idx) is_success = True + main_process_metrics.decode_preallocated_req_num.inc() else: if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len): self.llm_logger.debug(f"D Resource available, processing task {task.request_id}") @@ -2090,24 +1890,54 @@ def _process_allocate_resource_requests(): break for idx in sorted(processed_indices, reverse=True): + trace_print( + LoggingEventName.DECODE_PROCESS_PREALLOCAT_REQUEST_END, + allocate_resource_requests[idx].request_id, + allocate_resource_requests[idx].user, + ) allocate_resource_requests.pop(idx) def _process_prefilled_requests(): nonlocal prefilled_request_ouputs ready_request_outputs = [] waiting_request_outputs = [] + ghost_request_outputs = [] for req_output in prefilled_request_ouputs: - if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id): + req_id = req_output.request_id + if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_id): + if ( + req_id in self.resource_manager.waiting_abort_req_id_set + or req_id in self.resource_manager.to_be_aborted_req_id_set + ): + ghost_request_outputs.append(req_output) + continue # ensure the api_server and scheduler in decode have # received the request sent by the client waiting_request_outputs.append(req_output) continue req_output.finished = False ready_request_outputs.append(req_output) + trace_print(LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_START, req_output.request_id, "") self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}") prefilled_request_ouputs = waiting_request_outputs + + for req_output in ghost_request_outputs: + req_id = req_output.request_id + self.llm_logger.warning( + f"Pause drain: reaping prefilled-output ghost {req_id} " + "(scheduler never registered, marked for abort -- breaks deadlock)" + ) + try: + self.resource_manager.pre_recycle_resource(req_id) + except Exception as e: + self.llm_logger.warning(f"pre_recycle_resource({req_id}) failed: {e}") + self.resource_manager.waiting_abort_req_id_set.discard(req_id) + self.resource_manager.to_be_aborted_req_id_set.discard(req_id) + if req_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[req_id] + if self.cfg.splitwise_version == "v1": # decode return first token to client self.scheduler.put_results(ready_request_outputs) @@ -2117,6 +1947,8 @@ def _process_prefilled_requests(): else: for req_output in ready_request_outputs: request_id = req_output.request_id + main_process_metrics.decode_preallocated_req_num.dec() + trace_print(LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_END, request_id, "") if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue self.llm_logger.warning(f"{request_id} need not decode after first token") @@ -2130,6 +1962,7 @@ def _process_prefilled_requests(): self.llm_logger.warning( f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource." ) + main_process_metrics.failed_recv_first_token_req_num.inc() self.resource_manager.pre_recycle_resource(request_id) if request_id in self.token_processor.tokens_counter: del self.token_processor.tokens_counter[request_id] @@ -2138,6 +1971,32 @@ def _process_prefilled_requests(): self.token_processor.tokens_counter[request_id] = 1 if envs.FD_ENABLE_INTERNAL_ADAPTER: # first token sent by D instance self.scheduler.put_results([req_output]) + + # Storage pool mode: D reads cache from storage before adding to running queue + if envs.FD_PD_TRANSFER_VIA_STORAGE: + request = self.resource_manager.requests[request_id] + self.llm_logger.info(f"[PD Storage] D reading cache from storage, request_id: {request_id}") + storage_block_ids = self.resource_manager.cache_manager.read_cache_from_storage_for_pd(request) + if not storage_block_ids: + self.llm_logger.error( + f"[PD Storage] D failed to read cache from storage, " f"request_id: {request_id}" + ) + self.resource_manager.pre_recycle_resource(request_id) + if request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[request_id] + req_output.error_code = 502 + req_output.error_msg = ( + f"PD Storage Error: D failed to read all blocks from storage, " + f"request_id: {request_id}" + ) + req_output.finished = True + self.scheduler.put_results([req_output]) + continue + self.llm_logger.info( + f"[PD Storage] D successfully read cache from storage, " + f"request_id: {request_id}, blocks: {len(storage_block_ids)}" + ) + self.resource_manager.add_prefilled_request(req_output) self.llm_logger.info(f"D has successfully added prefilled request, {request_id}") @@ -2378,7 +2237,7 @@ def _setting_environ_variables(self): if self.cfg.scheduler_config.splitwise_role == "prefill": variables["FLAGS_fmt_write_cache_completed_signal"] = 1 - if self.cfg.model_config.enable_mm: + if self.cfg.enable_mm_runtime: variables["FLAGS_max_partition_size"] = 1024 command_prefix = "" @@ -2509,6 +2368,8 @@ def _start_worker_service(self): "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, "enable_entropy": self.cfg.model_config.enable_entropy, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, } for worker_flag, value in worker_store_true_flag.items(): if value: @@ -2547,10 +2408,47 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) after num_gpu_blocks is known + self.routing_cache_manager = None + if self.cfg.routing_replay_config.enable_routing_replay: + self._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": device_ids = self.cfg.parallel_config.device_ids.split(",") self.cache_manager_processes = self.start_cache_service(device_ids, self.ipc_signal_suffix) + def _init_routing_cache_manager(self, num_gpu_blocks: int): + """Create RoutingCacheManager (includes SharedMemory host buffer) after profiling.""" + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingCacheManager, + RoutingHostBufferView, + ) + + self.routing_cache_manager = RoutingCacheManager( + fd_config=self.cfg, + num_gpu_blocks=num_gpu_blocks, + ) + + # Pass routing_cache_manager to TokenProcessor for local/rdma store dispatch + self.token_processor.routing_cache_manager = self.routing_cache_manager + + # Set routing_host_view on resource_manager for PD disaggregation (D side) + if hasattr(self, "resource_manager") and hasattr(self.resource_manager, "routing_host_view"): + rrc = self.cfg.routing_replay_config + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + max_num_kv_tokens = num_gpu_blocks * self.cfg.cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + try: + self.resource_manager.routing_host_view = RoutingHostBufferView( + shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name + ) + except FileNotFoundError: + self.llm_logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {shm_name} not found for resource_manager" + ) + def check_health(self, time_interval_threashold=30): """ Check the health of the model server by checking whether all workers are alive. @@ -2684,3 +2582,21 @@ def detect_thread(): except Exception: pass return True + + def _resolve_abort_targets(self, abort_all, req_ids): + """ + Resolve abort target request IDs. + """ + now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()) + self.llm_logger.debug(f"now_reqs: {now_reqs}") + + if abort_all: + return list(now_reqs) + + target_req_ids = [] + for rid in req_ids: + if rid in now_reqs: + target_req_ids.append(rid) + elif f"{rid}_0" in now_reqs: + target_req_ids.append(f"{rid}_0") + return target_req_ids diff --git a/fastdeploy/engine/common_engine_prepare_mixin.py b/fastdeploy/engine/common_engine_prepare_mixin.py new file mode 100644 index 00000000000..c6ea2f3ee9d --- /dev/null +++ b/fastdeploy/engine/common_engine_prepare_mixin.py @@ -0,0 +1,282 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import threading +import time +import traceback + +import fastdeploy.metrics.trace as tracing +from fastdeploy.engine.request import RequestOutput +from fastdeploy.metrics.metrics import main_process_metrics +from fastdeploy.trace.constants import LoggingEventName +from fastdeploy.trace.trace_logger import print as trace_print +from fastdeploy.utils import envs + + +class EngineServicePrepareMixin: + def _fetch_request_mixed(self) -> bool: + """Fetch and prepare requests for a mixed instance. Returns True if tasks were fetched.""" + # FIXME: to validate if it's necessary for avoiding error when enable mtp + if len(self.resource_manager.waiting) > 0: + return False + + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.model_config.max_model_len + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request(task) + + return True + + def _fetch_request_decode(self) -> bool: + """Consume scheduler queue for decode instance to prevent memory accumulation. + Returns True if tasks were consumed.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + # Tasks are intentionally discarded - decode receives requests via _decode_process_splitwise_requests + return len(tasks) > 0 + + def _fetch_request_prefill(self) -> bool: + """Fetch and prepare requests for a prefill instance. Returns True if tasks were fetched.""" + num_prefill_batch = min( + int(self.resource_manager.available_batch()), + self.cfg.max_prefill_batch, + ) + max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens + available_blocks = self.cfg.cache_config.max_block_num_per_seq + + tasks = self.scheduler.get_requests( + available_blocks=available_blocks, + block_size=self.cfg.cache_config.block_size, + reserved_output_blocks=0, + max_num_batched_tokens=max_num_batched_tokens, + batch=num_prefill_batch, + ) + if not tasks: + return False + + for task in tasks: + task.metrics.engine_get_req_time = time.time() + trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", "")) + + self.llm_logger.debug( + f"Engine has fetched tasks from {self.scheduler.__class__.__name__}: {[task.request_id for task in tasks]}" + ) + + # Start async preprocess for all tasks in this batch + for task in tasks: + self.resource_manager.apply_async_preprocess(task) + + # P-side resource preallocation + D-side coordination + failed_tasks = [] + if envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for request: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + while True: + self.split_connector.send_splitwise_tasks([task], task.idx) + status, msg = self.split_connector.check_decode_allocated(task) + if status: + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print( + LoggingEventName.ASK_DECODE_RESOURCE_END, + task.request_id, + getattr(task, "user", ""), + ) + break + else: + self.llm_logger.warning( + f"D failed to allocate resource for request {task.request_id}, try again." + ) + time.sleep(0.05) + + self.llm_logger.debug(f"D has allocated resource for request: {task.request_id}") + else: + for task in tasks: + # assure can allocate block ids in P + while not self.resource_manager.preallocate_resource_in_p(task): + time.sleep(0.005) + + self.llm_logger.debug( + f"P has allocated resources and then ask D resource for req_id: {task.request_id}" + ) + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_START, task.request_id, getattr(task, "user", "")) + task.metrics.ask_decode_resource_start_time = time.time() + self.split_connector.send_splitwise_tasks([task], task.idx) + + for task in tasks: + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + task.metrics.ask_decode_resource_finish_time = time.time() + trace_print(LoggingEventName.ASK_DECODE_RESOURCE_END, task.request_id, getattr(task, "user", "")) + if not status: + error_msg = ( + f"PD Error: prefill failed to apply for resource from decode, " + f"req: {task.request_id}, msg:{msg}." + ) + self.llm_logger.error(error_msg) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=error_msg, + ) + ] + ) + main_process_metrics.reschedule_req_num.inc() + failed_tasks.append(task) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Check and wait async preprocess + if tasks: + need_check_req_ids = [task.request_id for task in tasks] + failed_tasks = [] + + while need_check_req_ids: + still_in_progress = False + for task in tasks: + if task.request_id not in need_check_req_ids: + continue + + result = self.resource_manager.waiting_async_process(task) + if result is False: # async preprocess success + need_check_req_ids.remove(task.request_id) + elif result is True: + still_in_progress = True + elif result is None: # async preprocess failed + failed_tasks.append(task) + need_check_req_ids.remove(task.request_id) + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=task.error_code, + error_msg=task.error_message, + ) + ] + ) + + if still_in_progress: + time.sleep(0.005) + + for tmp_task in failed_tasks: + tasks.remove(tmp_task) + self.resource_manager.pre_recycle_resource(tmp_task.request_id) + + # Send cache info to messager (skip in storage pool mode - messager is bypassed) + if tasks and not envs.FD_PD_TRANSFER_VIA_STORAGE: + self.split_connector.send_cache_info_to_messager(tasks, 0) + + # Fetch requests and add them to the scheduling queue + if tasks: + for task in tasks: + task.metrics.add_req_to_resource_manager_time = time.time() + trace_print(LoggingEventName.RESOURCE_ALLOCATE_START, task.request_id, getattr(task, "user", "")) + self.resource_manager.add_request_in_p(tasks) + self.llm_logger.info(f"P add requests into running queue: {[task.request_id for task in tasks]}") + + return True + + def _fetch_loop(self, fetch_fn, thread_idx: int): + """Fetch loop run by each worker thread.""" + tracing.trace_set_thread_info(f"Prepare Request for Scheduling - thread {thread_idx}") + while self.running: + try: + with self._pause_cond: + self._pause_cond.wait_for(lambda: not self.is_paused) + fetch_fn() + time.sleep(0.02) + except Exception as e: + self.llm_logger.error(f"fetching request error in worker-{thread_idx}: {e} {traceback.format_exc()}") + time.sleep(0.02) + + def _prepare_request_v1(self): + """Prepare request and send to the queue for scheduling""" + tracing.trace_set_thread_info("Prepare Request for Scheduling") + role = self.cfg.scheduler_config.splitwise_role + num_workers = envs.FD_PREFILL_PREPARE_REQ_THREAD_NUM if role == "prefill" else 1 + self.llm_logger.info(f"prepare request for scheduling, role: {role}, num_workers: {num_workers}") + + fetch_fn = { + "mixed": self._fetch_request_mixed, + "prefill": self._fetch_request_prefill, + "decode": self._fetch_request_decode, + }[role] + + self._fetch_threads = [] + for i in range(num_workers): + t = threading.Thread( + target=self._fetch_loop, + args=(fetch_fn, i), + daemon=True, + name=f"fetch-{i}", + ) + t.start() + self._fetch_threads.append(t) + + # Keep this thread alive for graceful shutdown + while self.running: + time.sleep(1.0) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 283693fae8c..996a4f9d68a 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -143,6 +143,12 @@ def start(self, api_server_pid=None): self.engine.create_data_processor() self.data_processor = self.engine.data_processor + # Create RoutingCacheManager when skipping profiling (num_gpu_blocks_override is set) + if not self.do_profile and self.cfg.routing_replay_config.enable_routing_replay: + num_gpu_blocks = self.cfg.cache_config.num_gpu_blocks_override + if num_gpu_blocks is not None: + self.engine._init_routing_cache_manager(num_gpu_blocks) + # If block numer is specified and model is deployed in mixed mode, start cache manager first if not self.do_profile and self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): @@ -458,6 +464,11 @@ def _exit_sub_services(self): if hasattr(self, "zmq_server") and self.zmq_server is not None: self.zmq_server.close() + if hasattr(self, "engine") and hasattr(self.engine, "routing_cache_manager"): + if self.engine.routing_cache_manager is not None: + self.engine.routing_cache_manager.close() + self.engine.routing_cache_manager = None + if hasattr(self, "dp_processed"): for p in self.dp_processed: console_logger.info(f"Waiting for worker {p.pid} to exit") @@ -655,6 +666,9 @@ def _start_worker_service(self): "enable_entropy": self.cfg.model_config.enable_entropy, "ep_prefill_use_worst_num_tokens": self.cfg.parallel_config.ep_prefill_use_worst_num_tokens, "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule, + "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask, + "enable_flashinfer_allreduce_fusion": self.cfg.parallel_config.enable_flashinfer_allreduce_fusion, + "enable_moe_scores_elementwise_fuse": self.cfg.scheduler_config.enable_moe_scores_elementwise_fuse, } for worker_flag, value in worker_store_true_flag.items(): if value: @@ -762,6 +776,11 @@ def _stop_profile(self): num_gpu_blocks = self.get_profile_block_num_signal.value[0] self.cfg.cache_config.reset(num_gpu_blocks) self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) + + # Create RoutingCacheManager (SharedMemory) before starting cache service + if self.cfg.routing_replay_config.enable_routing_replay: + self.engine._init_routing_cache_manager(num_gpu_blocks) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": if not current_platform.is_intel_hpu(): device_ids = self.cfg.parallel_config.device_ids.split(",") diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 81fe93e52a4..5958b3d9bd3 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -109,7 +109,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): if envs.FD_ENABLE_RETURN_TEXT: self.engine.create_data_processor() if self.cfg.scheduler_config.name == "dp": - self.cfg.init_cache_info() + self.cfg.init_pd_info() self.engine.scheduler.start(local_data_parallel_id) if ipc_signal_suffix is not None: @@ -122,7 +122,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id): self.llm_logger.info(f"start expert service {local_data_parallel_id}") if self.cfg.scheduler_config.name == "splitwise": - self.cfg.init_cache_info() + self.cfg.init_pd_info() role = self.cfg.scheduler_config.splitwise_role host_ip = self.cfg.host_ip self.engine.scheduler.start(role, host_ip, self.cfg.register_info) diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 0e95cd5e1fb..05ee4a348ea 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -50,10 +50,11 @@ class RequestStatus(Enum): WAITING = 0 - RUNNING = 1 - PREEMPTED = 2 - FINISHED = 3 - ABORT = 4 + RUNNING_PREFILL = 1 + RUNNING_DECODE = 2 + PREEMPTED = 3 + FINISHED = 4 + ABORT = 5 class RequestType(Enum): @@ -205,6 +206,7 @@ def __init__( self.metrics = RequestMetrics() else: self.metrics = metrics + self.metrics.prompt_token_ids_len = self.prompt_token_ids_len # from ChatCompletionRequest or CompletionRequest self.user = user self.metadata = metadata @@ -727,6 +729,10 @@ class CompletionOutput: delta_message: Optional[DeltaMessage] = None multipart: Optional[list[Any]] = None num_image_tokens: Optional[int] = None + # Sparse indices of retained vocab ids: + # - Non-MTP: list[int] + # - MTP: list[list[int]] + sampling_mask: Optional[Any] = None def to_dict(self): """ @@ -745,6 +751,7 @@ def to_dict(self): "text": self.text, "reasoning_content": self.reasoning_content, "reasoning_token_num": self.reasoning_token_num, + "sampling_mask": self.sampling_mask, } @classmethod @@ -872,6 +879,7 @@ class RequestMetrics: speculate_metrics: Optional[SpeculateMetrics] = None # cache related + prompt_token_ids_len: Optional[int] = None gpu_cache_token_num: Optional[int] = 0 cpu_cache_token_num: Optional[int] = 0 storage_cache_token_num: Optional[int] = 0 @@ -1019,6 +1027,7 @@ def __init__( self.ic_req_data = ic_req_data self.prompt_token_ids_len = prompt_token_ids_len self.trace_carrier = trace_carrier + self.routing_data = None # Optional[np.ndarray], [seq_len, num_moe_layers, top_k] if prompt_token_ids is None: self.prompt_token_ids = [] @@ -1114,12 +1123,15 @@ def from_dict(cls, d: dict): d.pop("metrics", None) metrics = None trace_carrier = d.pop("trace_carrier", {}) - return RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + routing_data = d.pop("routing_data", None) + obj = RequestOutput(**d, outputs=completion_output, metrics=metrics, trace_carrier=trace_carrier) + obj.routing_data = routing_data + return obj def to_dict(self): """convert RequestOutput into a serializable dict""" - return { + d = { "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, @@ -1137,6 +1149,9 @@ def to_dict(self): "prompt_token_ids_len": self.prompt_token_ids_len, "trace_carrier": self.trace_carrier, } + if self.routing_data is not None: + d["routing_data"] = self.routing_data + return d def get(self, key: str, default_value=None): if hasattr(self, key): diff --git a/fastdeploy/engine/resource_manager.py b/fastdeploy/engine/resource_manager.py index 609c88533bd..173cbdf9dd7 100644 --- a/fastdeploy/engine/resource_manager.py +++ b/fastdeploy/engine/resource_manager.py @@ -368,6 +368,12 @@ def info(self): total_block_number = self.total_block_number() available_block_num = self.available_block_num() used_block_num = total_block_number - available_block_num + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) + evictable_block_num = used_block_num - len(blocks_used_by_tasks) block_usage = used_block_num / total_block_number * 100 total_batch_number = len(self.stop_flags) available_batch_num = self.available_batch() @@ -375,8 +381,8 @@ def info(self): batch_usage = used_batch_num / total_batch_number * 100 info = ( f"ResourceManager info, " - f"total_block_number: {total_block_number}, total_batch_number: {total_batch_number}, " - f"available_block_num: {available_block_num}, available_batch: {available_batch_num}," + f"total_block_number: {total_block_number}, available_block_num: {available_block_num}, evictable_block_num: {evictable_block_num}, " + f"total_batch_number: {total_batch_number}, available_batch: {available_batch_num}," f"running_reqs: {used_batch_num}, block_usage: {block_usage:.2f}%, batch_usage: {batch_usage:.2f}%" ) return info diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 5af1605cdaf..04ae315f2c1 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -205,29 +205,36 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l self.need_block_num_map = dict() self.encoder_cache = None - if config.model_config.enable_mm and config.cache_config.max_encoder_cache > 0: + if config.enable_mm_runtime and config.cache_config.max_encoder_cache > 0: self.encoder_cache = EncoderCacheManager(config.cache_config.max_encoder_cache) self.processor_cache = None - if config.model_config.enable_mm and config.cache_config.max_processor_cache > 0: + self.routing_host_view = None # Set by Engine after RoutingHostBuffer creation + + if config.enable_mm_runtime and config.cache_config.max_processor_cache > 0: max_processor_cache_in_bytes = int(config.cache_config.max_processor_cache * 1024 * 1024 * 1024) self.processor_cache = ProcessorCacheManager(max_processor_cache_in_bytes) self.bos_client = None self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4) - self.init_reserve_output_block_num = ( - envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # int - self.decay_output_block_num = ( - envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # float - self.min_reserve_output_block_num = ( - envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL - ) # int - self.current_reserve_output_block_num = self.init_reserve_output_block_num - self.current_reserve_output_block_num_float = self.init_reserve_output_block_num - self.can_relax_prefill_strategy = True + self.use_new_token_ratio_reserve = envs.FD_USE_NEW_TOKEN_RATIO_RESERVE + if self.use_new_token_ratio_reserve: + self.init_new_token_ratio = envs.FD_INIT_NEW_TOKEN_RATIO + self.min_new_token_ratio = envs.FD_MIN_NEW_TOKEN_RATIO + self.new_token_ratio_decay = envs.FD_NEW_TOKEN_RATIO_DECAY + self.clip_max_new_tokens = envs.FD_CLIP_MAX_NEW_TOKENS + self.new_token_ratio = self.init_new_token_ratio + else: + self.init_reserve_output_block_num = envs.FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + self.decay_output_block_num = envs.FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + self.min_reserve_output_block_num = ( + envs.FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL + ) + self.current_reserve_output_block_num = self.init_reserve_output_block_num + self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num) + self.can_relax_prefill_strategy = True + # Scheduler-side requests that have not been moved into resource manager waiting queue yet. self.scheduler_unhandled_request_num = 0 @@ -238,13 +245,17 @@ def get_new_block_nums(self, request: Request, num_new_tokens: int): block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size - len(request.block_tables) - + block_num = max(block_num, 0) if self.config.speculative_config.method is not None: block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq) else: block_num = min(block_num, self.config.cache_config.max_block_num_per_seq) return block_num + def _is_decoding(self, request) -> bool: + """Return True if the request has finished prefill and is in the decoding phase.""" + return request.num_computed_tokens >= request.need_prefill_tokens + def _prepare_prefill_task(self, request, new_token_num): request.prefill_start_index = request.num_computed_tokens request.prefill_end_index = request.num_computed_tokens + new_token_num @@ -281,7 +292,9 @@ def recycle_abort_task(self, request_id): self.stop_flags[request.idx] = True # 设置停止标志 del self.requests[request_id] del self.req_dict[request_id] - self.to_be_aborted_req_id_set.remove(request_id) + self.to_be_aborted_req_id_set.discard(request_id) + self.waiting_abort_req_id_set.discard(request_id) + llm_logger.debug(f"request_id:{request_id} recycle abort task end") self.update_metrics() def _trigger_abort(self, request_id, scheduled_reqs): @@ -293,7 +306,8 @@ def _trigger_abort(self, request_id, scheduled_reqs): abort_request.cached_block_num = 0 scheduled_reqs.append(self._prepare_abort_task(abort_request)) self.to_be_aborted_req_id_set.add(request_id) - self.waiting_abort_req_id_set.remove(request_id) + self.waiting_abort_req_id_set.discard(request_id) + llm_logger.debug(f"request_id:{request_id} trigger abort") def _info_each_block(self): """ @@ -304,15 +318,6 @@ def _info_each_block(self): f"req idx {req.idx} occupy {len(req.block_tables)} block_tables and {len(req.extend_block_tables)} extend_block_tables" ) - def _can_preempt(self): - """ - cannot preempt request which use extend block - """ - for req in self.running: - if not req.use_extend_tables: - return True - return False - def preempted_all(self): with self.lock: preempted_reqs = [] @@ -347,17 +352,49 @@ def wait_worker_inflight_requests_finish(self, timeout=60): f"still {len(self.to_be_rescheduled_request_id_set)} requests running" ) + def _select_preempt_candidate(self): + # Scan from back to front to find the last preemptable request + preempted_req = None + i = len(self.running) - 1 + while i >= 0: + candidate = self.running[i] + # Skip requests that are not in decode status + if candidate.status != RequestStatus.RUNNING_DECODE: + i -= 1 + continue + # Skip requests using extend tables + if candidate.use_extend_tables: + i -= 1 + continue + # Found a valid preempt target + preempted_req = candidate + break + return preempted_req, i + def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs): """ If the request cannot be scheduled, preempt the running request one by one until it can be scheduled. Last in, first out. + Only requests that is in decode status can be preempted. """ can_schedule = False - while self._can_preempt(): - if not self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): - preempted_req = self.running.pop() - if preempted_req.use_extend_tables: - self.running.insert(0, preempted_req) - continue + while True: + if self.cache_manager.can_allocate_gpu_blocks(num_new_blocks): + # The request can be scheduled. + can_schedule = True + break + else: + # Try to find a candidate request to preempt. + preempted_req, preempted_idx = self._select_preempt_candidate() + if preempted_req is None: + can_schedule = False + llm_logger.warning( + f"Preemption is triggered while no preemptable request can be found, scheduler may be hung! " + f"Running requests: {self.running}" + ) + break + + # Remove the preempted request from the running list + self.running.pop(preempted_idx) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.config.scheduler_config.splitwise_role == "decode": @@ -389,33 +426,82 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re llm_logger.debug( f"preempt {preempted_req.request_id} in idx {preempted_req.idx} with generated ids {preempted_req.output_token_ids}" ) + llm_logger.debug(self.info()) self._info_each_block() + self._reset_reserve_on_preemption() if preempted_req == request: # No more request to preempt. can_schedule = False break - else: - # The request can be scheduled. - can_schedule = True - break - self.current_reserve_output_block_num = self.init_reserve_output_block_num - self.current_reserve_output_block_num_float = self.init_reserve_output_block_num - self.can_relax_prefill_strategy = False + return can_schedule + def _reset_reserve_on_preemption(self): + """Reset reserved blocks on preemption.""" + if self.use_new_token_ratio_reserve: + if not self.running: + self.new_token_ratio = self.init_new_token_ratio + return + total_decoded_tokens = sum(len(req.output_token_ids) for req in self.running) + total_max_new_tokens = 0 + for req in self.running: + max_tokens = req.sampling_params.max_tokens + if max_tokens is None: + max_tokens = self.config.model_config.max_model_len - req.prompt_token_ids_len + total_max_new_tokens += max_tokens + num_running_decode = sum( + [1 if req.num_total_tokens > req.need_prefill_tokens else 0 for req in self.running] + ) + extra_decode_steps = ( + 16 * self.config.cache_config.block_size + ) # consider extra 16 blocks for each running decode request when estimating new token ratio + new_ratio = (total_decoded_tokens + extra_decode_steps * num_running_decode) / (total_max_new_tokens + 1) + self.new_token_ratio = min(new_ratio, self.init_new_token_ratio) + llm_logger.info( + f"Estimate new token ratio for preemption: {self.new_token_ratio}, " + f"total_decoded_tokens={total_decoded_tokens}, total_max_new_tokens={total_max_new_tokens}, num_running_decode={num_running_decode}" + ) + + else: + self.current_reserve_output_block_num = self.init_reserve_output_block_num + self.current_reserve_output_block_num_float = float(self.init_reserve_output_block_num) + self.can_relax_prefill_strategy = False + + def _get_running_request_reserve_blocks(self, request: Request) -> int: + """Estimate KV-cache blocks to reserve for a running request's future decode tokens. + + Aligned with SGLang's per-request budget estimation: + reserved_tokens = min(max_tokens - already_generated, CLIP_MAX_NEW_TOKENS) * new_token_ratio + then ceil-divided by block_size. The ratio decays each scheduling step so that + the reservation gradually relaxes; on preemption it resets to the initial value. + """ + max_tokens = getattr(request.sampling_params, "max_tokens", None) + if max_tokens is None: + max_tokens = self.config.model_config.max_model_len - request.prompt_token_ids_len + remaining_tokens = max_tokens - len(request.output_token_ids) + clipped_remaining = min(remaining_tokens, self.clip_max_new_tokens) + reserved_tokens = max(int(clipped_remaining * self.new_token_ratio), 0) + block_size = self.config.cache_config.block_size + return (reserved_tokens + block_size - 1) // block_size + def _get_can_schedule_prefill_threshold_block(self, num_chunk_new_block): - if self.can_relax_prefill_strategy: - can_schedule_block_num_threshold = num_chunk_new_block + """Compute the minimum free blocks required to admit a new prefill request.""" + if self.use_new_token_ratio_reserve: + reserve_blocks = sum(self._get_running_request_reserve_blocks(req) for req in self.running) + can_schedule_block_num_threshold = num_chunk_new_block + reserve_blocks else: - can_schedule_block_num_threshold = ( - num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num - ) - if self.config.speculative_config.method is not None: - can_schedule_block_num_threshold = min( - can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + if self.can_relax_prefill_strategy: + can_schedule_block_num_threshold = num_chunk_new_block + else: + can_schedule_block_num_threshold = ( + num_chunk_new_block + len(self.running) * self.current_reserve_output_block_num ) + if self.config.speculative_config.method is not None: + can_schedule_block_num_threshold = min( + can_schedule_block_num_threshold + 1, self.config.cache_config.max_block_num_per_seq + ) return can_schedule_block_num_threshold def _update_mm_hashes(self, request): @@ -550,7 +636,7 @@ def _get_num_new_tokens(self, request, token_budget): num_new_tokens = token_budget // self.config.cache_config.block_size * self.config.cache_config.block_size request.with_image = False - if not self.config.model_config.enable_mm: + if not self.config.enable_mm_runtime: return num_new_tokens inputs = request.multimodal_inputs @@ -686,7 +772,7 @@ def _compute_audio_prefix_count(end_idx, end_patch_idx): num_new_tokens = new_end_idx - pre_end_idx image_mask = input_ids[pre_end_idx:new_end_idx] == image_patch_id - request.with_image = image_mask.any() + request.with_image = bool(image_mask.any()) if request.with_image: pre_boundary_idx = np.searchsorted(img_boundaries_idx, pre_end_idx, side="left").item() if pre_boundary_idx == len(img_boundaries_idx): @@ -749,7 +835,7 @@ def cache_output_tokens(self, request): and self.config.scheduler_config.splitwise_role != "decode" ): with self.lock: - if request.num_computed_tokens >= request.need_prefill_tokens: # request is decoding + if self._is_decoding(request): # request is decoding self.cache_manager.cache_output_blocks(request, self.config.cache_config.block_size) def schedule(self): @@ -768,12 +854,22 @@ def get_enough_request(request, scheduled_reqs): scheduled_reqs: list[Request] = [] preempted_reqs: list[Request] = [] error_reqs: list[tuple[str, str]] = [] - token_budget = self.config.scheduler_config.max_num_batched_tokens + tokens_per_seq = ( + (self.config.speculative_config.num_speculative_tokens + 1) + if self.config.speculative_config is not None and self.config.speculative_config.method is not None + else 1 + ) + num_running_decode_reqs = sum(1 for req in self.running if self._is_decoding(req)) + token_budget = ( + self.config.scheduler_config.max_num_batched_tokens - num_running_decode_reqs * tokens_per_seq + ) need_abort_requests = [] # users trigger abortion + chunk_prefill_in_running_not_satisfied = False # First, schedule the RUNNING requests. req_index = 0 num_decoding_req_nums = 0 + while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] need_block_num = self.need_block_num_signal.value[request.idx] @@ -781,7 +877,7 @@ def get_enough_request(request, scheduled_reqs): self.need_block_num_map[request.request_id] = SignalConsumer(need_block_num, 1) self.need_block_num_signal.value[request.idx] = 0 - if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding + if self._is_decoding(request): # to be decoding if ( self.config.scheduler_config.splitwise_role == "prefill" ): # do not need to schedule for decoding @@ -828,7 +924,7 @@ def get_enough_request(request, scheduled_reqs): # Prepare decoding task scheduled_reqs.append(self._prepare_decode_task(request)) num_decoding_req_nums += 1 - token_budget -= 1 + # Decode token cost has been pre-deducted upfront (num_running_decode_reqs * tokens_per_seq). if ( request.use_extend_tables and request.request_id not in self.using_extend_tables_req_id @@ -906,27 +1002,29 @@ def _allocate_decode_and_extend(): req_index += 1 continue num_new_block = self.get_new_block_nums(request, num_new_tokens) - # Allocate blocks to prefill - if self.cache_manager.can_allocate_gpu_blocks(num_new_block): - request.block_tables.extend( - self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) + if self.config.scheduler_config.splitwise_role == "prefill": + # for prefill instance, do not set threshold for running requests + can_schedule_block_num_threshold = 0 + else: + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( + num_new_block ) - # Prepare prefill task - scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) - else: # Not enough blocks to allocate, trigger preemption - can_schedule = self._trigger_preempt(request, num_new_block, preempted_reqs, scheduled_reqs) - if not can_schedule: - break + # Allocate blocks to prefill + if self.cache_manager.can_allocate_gpu_blocks(can_schedule_block_num_threshold): request.block_tables.extend( self.cache_manager.allocate_gpu_blocks(num_new_block, request.request_id) ) # Prepare prefill task scheduled_reqs.append(self._prepare_prefill_task(request, num_new_tokens)) + else: # Not enough blocks to allocate + chunk_prefill_in_running_not_satisfied = True + break # For chunk prefill request, if not satisfy condition for prefill, just break token_budget -= num_new_tokens request.num_computed_tokens += num_new_tokens if ( self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode" + and self.config.scheduler_config.splitwise_role != "prefill" ): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens @@ -938,10 +1036,16 @@ def _allocate_decode_and_extend(): self.running.remove(request) # Second, schedule the WAITING requests. - if not preempted_reqs: + if (not preempted_reqs) and (not chunk_prefill_in_running_not_satisfied): skip_requests: list[Request] = [] while self.waiting and token_budget > 0: - if len(self.running) == self.max_num_seqs: + if ( + len(self.running) + + len(self.to_be_rescheduled_request_id_set) + + len(self.to_be_aborted_req_id_set) + + sum([req.status == RequestStatus.PREEMPTED for req in self.waiting]) + >= self.max_num_seqs + ): break request = self.waiting[0] @@ -966,9 +1070,12 @@ def _allocate_decode_and_extend(): self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend ): - if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size + ) + if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) @@ -1018,7 +1125,7 @@ def _allocate_decode_and_extend(): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) - request.status = RequestStatus.RUNNING + request.status = RequestStatus.RUNNING_PREFILL if self.config.scheduler_config.splitwise_role == "mixed": allocated_position = self.get_available_position() request.idx = allocated_position @@ -1027,6 +1134,7 @@ def _allocate_decode_and_extend(): self.req_dict[request.request_id] = allocated_position llm_logger.debug(f"req_id:{request.request_id} allocate pos end") else: + # Warning: _free_blocks before update_cache_blocks may cause storage blocks leak if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break @@ -1042,9 +1150,12 @@ def _allocate_decode_and_extend(): self.cache_manager.num_cpu_blocks > 0 or self.config.cache_config.kvcache_storage_backend ): - if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold = self._get_can_schedule_prefill_threshold_block( (request.need_prefill_tokens + self.config.cache_config.block_size - 1) // self.config.cache_config.block_size + ) + if not self.cache_manager.can_allocate_gpu_blocks( + can_schedule_block_num_threshold ): # to prevent block allocation for matching in hierarchical cache and cause dead lock break success = self.get_prefix_cached_blocks(request) @@ -1087,8 +1198,9 @@ def _allocate_decode_and_extend(): self.cache_manager.update_cache_blocks( request, self.config.cache_config.block_size, request.num_computed_tokens ) - request.status = RequestStatus.RUNNING + request.status = RequestStatus.RUNNING_PREFILL else: + # Warning: _free_blocks before update_cache_blocks may cause storage blocks leak if self.config.cache_config.enable_prefix_caching: self._free_blocks(request) break @@ -1101,14 +1213,20 @@ def _allocate_decode_and_extend(): if scheduled_reqs: llm_logger.debug(f"schedued_reqs: {scheduled_reqs}") - self.current_reserve_output_block_num_float -= self.decay_output_block_num - self.current_reserve_output_block_num = max( - int(self.current_reserve_output_block_num_float), - self.min_reserve_output_block_num, - 0, - ) - if self.current_reserve_output_block_num == 0: - self.can_relax_prefill_strategy = True + if self.use_new_token_ratio_reserve: + self.new_token_ratio = max( + self.new_token_ratio - self.new_token_ratio_decay, + self.min_new_token_ratio, + ) + else: + self.current_reserve_output_block_num_float -= self.decay_output_block_num + self.current_reserve_output_block_num = max( + int(self.current_reserve_output_block_num_float), + self.min_reserve_output_block_num, + 0, + ) + if self.current_reserve_output_block_num == 0: + self.can_relax_prefill_strategy = True self._log_console_scheduler_metrics(scheduled_reqs) @@ -1245,6 +1363,8 @@ def get_prefix_cached_blocks(self, request: Request): Match and fetch cache for a task. """ try: + trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_START, request.request_id, getattr(request, "user", "")) + (common_block_ids, matched_token_num, metrics) = self.cache_manager.request_match_blocks( request, self.config.cache_config.block_size ) @@ -1292,11 +1412,14 @@ def get_prefix_cached_blocks(self, request: Request): request.metrics.storage_cache_token_num = metrics["storage_match_token_num"] request.metrics.cpu_cache_prepare_time = metrics["cpu_cache_prepare_time"] request.metrics.storage_cache_prepare_time = metrics["storage_cache_prepare_time"] + request.metrics.prompt_token_ids_len = request.prompt_token_ids_len main_process_metrics.prefix_cache_token_num.inc(request.num_computed_tokens) main_process_metrics.prefix_gpu_cache_token_num.inc(request.metrics.gpu_cache_token_num) main_process_metrics.prefix_cpu_cache_token_num.inc(request.metrics.cpu_cache_token_num) + trace_print(LoggingEventName.PREPARE_PREFIX_CACHE_END, request.request_id, getattr(request, "user", "")) + return True except Exception as e: llm_logger.error(f"prefix match blocks error: {e}, {str(traceback.format_exc())} waiting reschedule...") @@ -1327,6 +1450,7 @@ def pre_recycle_resource(self, request_id: str): def add_request_in_p(self, requests: list[Request]): with self.lock: for request in requests: + request.status = RequestStatus.RUNNING_PREFILL self.running.append(request) def preallocate_resource_in_p(self, request: Request): @@ -1368,6 +1492,11 @@ def preallocate_resource_in_p(self, request: Request): self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position + + self.cache_manager.update_cache_blocks( + request, self.config.cache_config.block_size, request.need_prefill_tokens + ) + return True else: self._free_blocks(request) @@ -1417,7 +1546,6 @@ def preallocate_resource_in_d(self, request: Request): request.disaggregate_info["block_tables"] = request.block_tables allocated_position = self.get_available_position() request.idx = allocated_position - self.tasks_list[request.idx] = request self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position @@ -1455,14 +1583,46 @@ def add_prefilled_request(self, request_output: RequestOutput): ): request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids) request.need_prefill_tokens = len(request.prompt_token_ids) + 1 + request.status = RequestStatus.RUNNING_DECODE request_output.metrics.decode_recv_req_time = request.metrics.decode_recv_req_time request_output.metrics.decode_preallocate_req_time = request.metrics.decode_preallocate_req_time request.metrics = copy.deepcopy(request_output.metrics) request.metrics.decode_inference_start_time = time.time() request.metrics.update_decoder_start_time() + + # [R3] Write P's prefill routing data into D's routing_host_buffer + if ( + self.routing_host_view is not None + and hasattr(request_output, "routing_data") + and request_output.routing_data is not None + ): + try: + self._write_prefill_routing_to_host_buffer(request, request_output.routing_data) + except Exception as e: + llm_logger.warning(f"[R3] Failed to write prefill routing for {request_output.request_id}: {e}") + + self.tasks_list[request.idx] = request self.running.append(request) + def _write_prefill_routing_to_host_buffer(self, request, routing_data): + """ + Write P's prefill routing data into D's routing_host_buffer. + Uses D's block_tables to compute slot_mapping. + """ + import math + + seq_len = routing_data.shape[0] + block_size = self.config.cache_config.block_size + num_blocks_needed = math.ceil(seq_len / block_size) + block_ids = request.block_tables[:num_blocks_needed] + + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + self.routing_host_view.scatter(slot_mapping, routing_data) + def _free_blocks(self, request: Request): if self.config.cache_config.enable_prefix_caching and self.config.scheduler_config.splitwise_role != "decode": self.cache_manager.release_block_ids(request) @@ -1490,7 +1650,7 @@ def finish_requests_async(self, request_ids: Union[str, Iterable[str]]): def finish_requests(self, request_ids: Union[str, Iterable[str]]): llm_logger.info(f"recycle resources for requests: {request_ids}") - self.update_metrics(verbose=True) + self.update_metrics() try: if isinstance(request_ids, str): request_ids = (request_ids,) @@ -1522,11 +1682,18 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): del self.requests[req_id] if req_id in self.req_dict: del self.req_dict[req_id] + self.waiting_abort_req_id_set.discard(req_id) + self.to_be_aborted_req_id_set.discard(req_id) # Do not block the main thread here # Write cache to storage if kvcache_storage_backend is enabled for req in need_postprocess_reqs: - if self.config.scheduler_config.splitwise_role == "decode": + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: P already writes cache in token_processor before notifying D, + # only D needs to write here (including output tokens generated during decode) + if self.config.scheduler_config.splitwise_role == "decode": + self.cache_manager.write_all_cache_to_storage(req) + elif self.config.scheduler_config.splitwise_role == "decode": # D instance uses simplified write method (does not rely on Radix Tree) self.cache_manager.write_cache_to_storage_decode(req) else: @@ -1543,7 +1710,7 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]): except Exception as e: llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}") finally: - self.update_metrics(verbose=True) + self.update_metrics() def clear_data(self): self.waiting: deque[Request] = deque() @@ -1552,18 +1719,25 @@ def clear_data(self): def update_metrics(self, verbose=False): # Update metrics - num_tasks = sum([1 if task else 0 for task in self.tasks_list]) + num_requests_running = len(self.running) + num_requests_waiting = len(self.waiting) + num_requests_queuing = max(int(getattr(self, "scheduler_unhandled_request_num", 0) or 0), 0) blocks_used_by_tasks = set() for task in self.tasks_list: if task is not None: - blocks_used_by_tasks.update(task.block_tables) + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) main_process_metrics.available_gpu_block_num.set(self.total_block_number() - len(blocks_used_by_tasks)) main_process_metrics.batch_size.set(self.max_num_seqs - self.available_batch()) main_process_metrics.gpu_cache_usage_perc.set(self.get_gpu_cache_usage_perc()) - main_process_metrics.num_requests_running.set(len(self.running)) - main_process_metrics.num_requests_waiting.set(num_tasks - len(self.running)) + main_process_metrics.num_requests_running.set(num_requests_running) + main_process_metrics.num_requests_waiting.set(num_requests_waiting) + main_process_metrics.num_requests_queuing.set(num_requests_queuing) if verbose: - llm_logger.info(f"update metrics: running={len(self.running)}, waiting={num_tasks - len(self.running)}") + llm_logger.info( + f"update metrics: running={num_requests_running}, " + f"waiting={num_requests_waiting}, queuing={num_requests_queuing}" + ) def log_status(self): llm_logger.info( @@ -1589,6 +1763,13 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule total_blocks = self.total_block_number() free_blocks = self.available_block_num() used_blocks = max(total_blocks - free_blocks, 0) + # Evictable = used blocks not held by any running task + blocks_used_by_tasks = set() + for task in self.tasks_list: + if task is not None: + blocks_used_by_tasks.update(getattr(task, "block_tables", [])) + blocks_used_by_tasks.update(getattr(task, "extend_block_tables", [])) + evictable_blocks = used_blocks - len(blocks_used_by_tasks) tokens_used = used_blocks * self.config.cache_config.block_size token_usage = used_blocks / total_blocks if total_blocks > 0 else 0.0 running_cnt = len(self.running) @@ -1598,13 +1779,26 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule prefill_reqs = [r for r in scheduled_reqs if isinstance(r, Request) and r.task_type == RequestType.PREFILL] has_decode = any(getattr(r, "task_type", None) == RequestType.DECODE for r in scheduled_reqs) - self.scheduler_metrics_logger.log_prefill_batch( - prefill_reqs=prefill_reqs, - running_cnt=running_cnt, - queue_cnt=queue_cnt, - tokens_used=tokens_used, - token_usage=token_usage, - ) + if self.config.scheduler_config.splitwise_role == "decode": + self.scheduler_metrics_logger.log_decode_bootstrap_batch( + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + else: + self.scheduler_metrics_logger.log_prefill_batch( + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) if has_decode: has_prefill = len(prefill_reqs) > 0 graph_opt_cfg = self.config.graph_opt_config @@ -1629,4 +1823,6 @@ def _log_console_scheduler_metrics(self, scheduled_reqs: list[Request | Schedule tokens_used=tokens_used, token_usage=token_usage, use_cudagraph=use_decode_cudagraph, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, ) diff --git a/fastdeploy/engine/sched/scheduler_metrics_logger.py b/fastdeploy/engine/sched/scheduler_metrics_logger.py index 9e08375a395..b989584fe07 100644 --- a/fastdeploy/engine/sched/scheduler_metrics_logger.py +++ b/fastdeploy/engine/sched/scheduler_metrics_logger.py @@ -29,9 +29,10 @@ class SchedulerMetricsLogger: DEFAULT_DECODE_LOG_INTERVAL = 5 - def __init__(self, enabled: bool = True, dp_rank: int = 0) -> None: + def __init__(self, enabled: bool = True, dp_rank: int = 0, splitwise_role: str = "mixed") -> None: self.enabled = enabled self.dp_rank = dp_rank + self.splitwise_role = splitwise_role decode_log_interval = envs.FD_CONSOLE_DECODE_LOG_INTERVAL if decode_log_interval <= 0: decode_log_interval = self.DEFAULT_DECODE_LOG_INTERVAL @@ -65,13 +66,16 @@ def on_decode_tokens(self, num_tokens: int) -> None: with self._lock: self._decode_tokens_since_last += num_tokens - def log_prefill_batch( + def _log_prefill_like_batch( self, + batch_name: str, prefill_reqs: Iterable, running_cnt: int, queue_cnt: int, tokens_used: int, token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, ) -> None: if not self.enabled: return @@ -89,17 +93,62 @@ def log_prefill_batch( cached_tokens += getattr(req, "num_cached_tokens", 0) or 0 msg = ( - "Prefill batch, " + f"{batch_name}, " f"dp_rank: {self.dp_rank}, " + f"splitwise_role: {self.splitwise_role}, " f"#new-seq: {len(prefill_reqs)}, " f"#new-token: {new_tokens}, " f"#cached-token: {cached_tokens}, " f"token usage: {token_usage:.2f}, " + f"#free-block: {free_blocks}, " + f"#evictable-block: {evictable_blocks}, " f"#running-req: {running_cnt}, " f"#queue-req: {queue_cnt}, " ) self._logger.info(msg) + def log_prefill_batch( + self, + prefill_reqs: Iterable, + running_cnt: int, + queue_cnt: int, + tokens_used: int, + token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, + ) -> None: + self._log_prefill_like_batch( + batch_name="Prefill batch", + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + + def log_decode_bootstrap_batch( + self, + prefill_reqs: Iterable, + running_cnt: int, + queue_cnt: int, + tokens_used: int, + token_usage: float, + free_blocks: int = 0, + evictable_blocks: int = 0, + ) -> None: + self._log_prefill_like_batch( + batch_name="Decode bootstrap batch from prefill", + prefill_reqs=prefill_reqs, + running_cnt=running_cnt, + queue_cnt=queue_cnt, + tokens_used=tokens_used, + token_usage=token_usage, + free_blocks=free_blocks, + evictable_blocks=evictable_blocks, + ) + def log_decode_batch( self, running_cnt: int, @@ -107,6 +156,8 @@ def log_decode_batch( tokens_used: int, token_usage: float, use_cudagraph: bool, + free_blocks: int = 0, + evictable_blocks: int = 0, ) -> None: if not self.enabled: return @@ -126,9 +177,12 @@ def log_decode_batch( msg = ( "Decode batch, " f"dp_rank: {self.dp_rank}, " + f"splitwise_role: {self.splitwise_role}, " f"#running-req: {running_cnt}, " f"#token: {tokens_used}, " f"token usage: {token_usage:.2f}, " + f"#free-block: {free_blocks}, " + f"#evictable-block: {evictable_blocks}, " f"cuda graph: {use_cudagraph}, " f"gen throughput (token/s): {throughput:.2f}, " f"#queue-req: {queue_cnt}, " diff --git a/fastdeploy/entrypoints/api_server.py b/fastdeploy/entrypoints/api_server.py index 4f4d7f2250c..e182eb61fe9 100644 --- a/fastdeploy/entrypoints/api_server.py +++ b/fastdeploy/entrypoints/api_server.py @@ -123,7 +123,7 @@ def main(): parser = FlexibleArgumentParser() parser.add_argument("--port", default=9904, type=int, help="port to the http server") parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") - parser.add_argument("--workers", default=1, type=int, help="number of workers") + parser.add_argument("--workers", default=4, type=int, help="number of workers") parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() launch_api_server(args) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index f03a18594de..7fcb2e0fbc4 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -84,7 +84,7 @@ class EngineClient: def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers: int = 1, max_logprobs: int = 20): self.fd_config = fd_config self.tensor_parallel_size = self.fd_config.parallel_config.tensor_parallel_size - self.enable_mm = self.fd_config.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.max_logprobs = max_logprobs input_processor = InputPreprocessor( self.fd_config.model_config, @@ -93,6 +93,7 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers self.fd_config.mm_processor_kwargs, self.fd_config.tool_parser, self.enable_mm and self.fd_config.cache_config.max_processor_cache > 0, + enable_mm_runtime=self.enable_mm, ) self.enable_logprob = self.fd_config.model_config.enable_logprob self.data_processor = input_processor.create_processor() @@ -358,6 +359,7 @@ async def add_requests(self, task): task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) min_tokens = task.get("min_tokens", 1) + if "messages" in task: task["messages"] = None api_server_logger.info(f"task['max_tokens']:{task['max_tokens']}") @@ -596,6 +598,10 @@ def check_health(self, time_interval_threashold=30): async def run_control_method(self, request: ControlRequest): api_server_logger.info(f"Received control request: {request}") + request_id = request.request_id + dealer, response_queue = await self.connection_manager.get_connection(request_id) + if not envs.ZMQ_SEND_BATCH_DATA: + dealer.write([b"", request_id.encode("utf-8")]) req_dict = request.to_dict() if envs.ZMQ_SEND_BATCH_DATA: req_dict["zmq_worker_pid"] = self.worker_pid @@ -603,10 +609,6 @@ async def run_control_method(self, request: ControlRequest): self.zmq_client.send_json(req_dict) else: self.zmq_client.send_pyobj(req_dict) - request_id = request.request_id - dealer, response_queue = await self.connection_manager.get_connection(request_id) - if not envs.ZMQ_SEND_BATCH_DATA: - dealer.write([b"", request_id.encode("utf-8")]) try: # todo: support user specified timeout. default 600s is enough for most control cases response = await asyncio.wait_for(response_queue.get(), timeout=600) @@ -1044,6 +1046,18 @@ async def abort(self, request_id, n=1) -> None: api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids)) + async def abort_reqs(self, req_ids=None, abort_all=False): + """ + Fire-and-forget: abort multiple requests in one ZMQ message. + Used by /v1/abort_requests API. + """ + data = { + "status": RequestStatus.ABORT.value, + "abort_all": abort_all, + "req_ids": req_ids or [], + } + self._send_task(data) + def process_messages(self, messages): for message in messages: if message["role"] == "assistant" and "tool_calls" in message: diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index b96d93ab312..4e76a62ca65 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -486,13 +486,8 @@ async def abort_requests(request: Request): if not abort_all and not req_ids: return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"}) - control_request = ControlRequest( - request_id=f"control-{uuid.uuid4()}", - method="abort_requests", - args={"abort_all": abort_all, "req_ids": req_ids or []}, - ) - control_response = await app.state.engine_client.run_control_method(control_request) - return control_response.to_api_json_response() + await app.state.engine_client.abort_reqs(req_ids=req_ids or [], abort_all=abort_all) + return Response(status_code=200) def wrap_streaming_generator(original_generator: AsyncGenerator): diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 3560f3a8aef..a546017d30f 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -268,8 +268,10 @@ class ChatCompletionResponseChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] speculate_metrics: Optional[SpeculateMetrics] = None + # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token + sampling_mask: Optional[List[List[int]]] = None class ChatCompletionResponse(BaseModel): @@ -283,6 +285,7 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class LogProbEntry(BaseModel): @@ -333,7 +336,10 @@ class ChatCompletionResponseStreamChoice(BaseModel): logprobs: Optional[LogProbs] = None draft_logprobs: Optional[LogProbs] = None prompt_logprobs: Optional[PromptLogprobs] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + # Per-token index list of retained positions after top_p sampling. + # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step). + sampling_mask: Optional[List[List[int]]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None arrival_time: Optional[float] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -369,7 +375,7 @@ class CompletionResponseChoice(BaseModel): draft_logprobs: Optional[CompletionLogprobs] = None prompt_logprobs: Optional[PromptLogprobs] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None @@ -385,6 +391,7 @@ class CompletionResponse(BaseModel): model: str choices: List[CompletionResponseChoice] usage: UsageInfo + routed_experts: Optional[str] = None class CompletionLogprobs(BaseModel): @@ -415,7 +422,7 @@ class CompletionResponseStreamChoice(BaseModel): prompt_tokens: Optional[str] = None completion_tokens: Optional[str] = None reasoning_content: Optional[str] = None - finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort", "pd_reschedule"]] = None tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None speculate_metrics: Optional[SpeculateMetrics] = None diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py index 41761963be8..acdfcaeade4 100644 --- a/fastdeploy/entrypoints/openai/response_processors.py +++ b/fastdeploy/entrypoints/openai/response_processors.py @@ -89,7 +89,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ decode_type = request_output["outputs"].get("decode_type", 0) or 0 if decode_type == 0: # text tts = req_id in self._audio_buffer - if token_ids[-1] == self.eos_token_id: + if token_ids and token_ids[-1] == self.eos_token_id: all_audio_tokens = self._audio_buffer.pop(req_id, []) else: all_audio_tokens = None @@ -186,7 +186,7 @@ async def process_response_chat(self, request_outputs, stream, include_stop_str_ else: self.accumulate_token_ids(request_output) token_ids = request_output["outputs"]["token_ids"] - if token_ids[-1] == self.eos_token_id: + if token_ids and token_ids[-1] == self.eos_token_id: multipart = [] num_image_tokens = 0 for part in self._multipart_buffer: diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index eb106f6550f..1bbee7b07f9 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -435,6 +435,11 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs_res, draft_logprobs=draft_logprobs_res, + sampling_mask=( + self._make_sampling_mask_list(output["sampling_mask"]) + if output.get("sampling_mask") is not None + else None + ), arrival_time=arrival_time, speculate_metrics=output_speculate_metrics, ) @@ -580,8 +585,10 @@ async def chat_completion_full_generator( decoder_base_url=self.tokenizer_base_url, ) prompt_logprobs_res_list = [[] for _ in range(num_choices)] + sampling_mask_list = [[] for _ in range(num_choices)] speculate_metrics = [None for _ in range(num_choices)] choices = [] + routing_data_result = None while num_choices > 0: if self.engine_client.check_model_weight_status(): return ErrorResponse( @@ -614,10 +621,22 @@ async def chat_completion_full_generator( request=request, ) async for data in generator: - if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) idx = int(data["request_id"].split("_")[-1]) - # api_server_logger.debug(f"Client {request_id} received: {data}") + if data.get("error_code", 200) != 200: + # Error response - include already-generated tokens in the response + data["outputs"] = { + "text": "", + "completion_tokens": "", + "reasoning_content": "", + "tool_calls": None, + "reasoning_token_num": 0, + "num_image_tokens": 0, + "token_ids": [], + "top_logprobs": None, + "draft_top_logprobs": None, + } + data["metrics"] = data.get("metrics") or {} + data["finished"] = True previous_num_tokens[idx] += len(data["outputs"]["token_ids"]) completion_token_ids[idx].extend(data["outputs"]["token_ids"]) # The logprob for handling the response @@ -660,6 +679,9 @@ async def chat_completion_full_generator( ) if prompt_logprobs_res: prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res)) + output_sampling_mask = output.get("sampling_mask", None) + if output_sampling_mask is not None: + sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask)) speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None) if data["finished"]: trace_carrier = data.get("trace_carrier") @@ -695,10 +717,20 @@ async def chat_completion_full_generator( draft_logprob_contents=draft_logprob_contents, response_processor=response_processor, prompt_logprobs_res_list=prompt_logprobs_res_list, + sampling_mask_list=sampling_mask_list, max_tokens=max_tokens, speculate_metrics=speculate_metrics[idx], ) choices.append(choice) + if data.get("routing_data") is not None: + import base64 + + import numpy as np + + rd = data["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routing_data_result = base64.b64encode(rd.tobytes()).decode("utf-8") finally: trace_print(LoggingEventName.POSTPROCESSING_END, request_id, getattr(request, "user", "")) tracing.trace_req_finish(request_id) @@ -730,6 +762,7 @@ async def chat_completion_full_generator( model=model_name, choices=choices, usage=usage, + routed_experts=routing_data_result, ) api_server_logger.info(f"Chat response: {res.model_dump_json()}") return res @@ -749,6 +782,7 @@ async def _create_chat_completion_choice( logprob_contents: list, draft_logprob_contents: list, prompt_logprobs_res_list: list, + sampling_mask_list: list, response_processor: ChatResponseProcessor, max_tokens: int, speculate_metrics: SpeculateMetrics | None, @@ -756,6 +790,26 @@ async def _create_chat_completion_choice( idx = int(data["request_id"].split("_")[-1]) output = data["outputs"] + finish_reason = "stop" + if previous_num_tokens != max_tokens: + finish_reason = "stop" + if output.get("tool_calls"): + finish_reason = "tool_calls" + else: + finish_reason = "length" + if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: + finish_reason = "recover_stop" + + if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: + finish_reason = "abort" + + if data.get("error_msg", None) is not None and "PD Error" in data["error_msg"]: + finish_reason = "pd_reschedule" + + return_completion_token_ids = False + if request.return_token_ids or finish_reason == "pd_reschedule": + return_completion_token_ids = True + if output is not None and output.get("metrics") and output["metrics"].get("request_start_time"): main_process_metrics.e2e_request_latency.observe( time.time() - data.get("metrics").get("request_start_time") @@ -765,7 +819,7 @@ async def _create_chat_completion_choice( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_calls"), prompt_token_ids=prompt_token_ids if request.return_token_ids else None, - completion_token_ids=completion_token_ids if request.return_token_ids else None, + completion_token_ids=completion_token_ids if return_completion_token_ids else None, prompt_tokens=prompt_tokens if request.return_token_ids else None, completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, ) @@ -787,29 +841,23 @@ async def _create_chat_completion_choice( if prompt_logprobs_res_list[idx]: prompt_logprobs_full_res = prompt_logprobs_res_list[idx] + # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens. + sampling_mask_full_res = None + if sampling_mask_list and sampling_mask_list[idx]: + sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step] + num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) num_image_tokens[idx] = output.get("num_image_tokens", 0) or 0 - finish_reason = "stop" - if previous_num_tokens != max_tokens: - finish_reason = "stop" - if output.get("tool_calls"): - finish_reason = "tool_calls" - else: - finish_reason = "length" - if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]: - finish_reason = "recover_stop" - - if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]: - finish_reason = "abort" return ChatCompletionResponseChoice( index=idx, message=message, logprobs=logprobs_full_res, draft_logprobs=draft_logprobs_full_res, prompt_logprobs=prompt_logprobs_full_res, + sampling_mask=sampling_mask_full_res, finish_reason=finish_reason, speculate_metrics=speculate_metrics, ) @@ -1000,3 +1048,18 @@ def _make_logprob_dict( ) for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) } + + @staticmethod + def _make_sampling_mask_list(sampling_mask) -> List[List[int]]: + """Wrap sampling_mask into a uniform List[List[int]] format. + + sampling_mask is already in sparse-index form (no bool-to-index conversion needed): + Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]] + MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...] + """ + assert sampling_mask is not None + if sampling_mask and isinstance(sampling_mask[0], list): + # MTP: already List[List[int]], return as-is + return sampling_mask + # Non-MTP: already List[int], wrap in outer list for uniform format + return [sampling_mask] diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index b277576a1fc..e7941197f37 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -305,7 +305,15 @@ async def completion_full_generator( for data in response: rid = int(data["request_id"].split("_")[-1]) if data.get("error_code", 200) != 200: - raise ValueError("{}".format(data["error_msg"])) + data["outputs"] = { + "text": "", + "completion_tokens": "", + "token_ids": [], + "top_logprobs": None, + "draft_top_logprobs": None, + } + data["metrics"] = data.get("metrics") or {} + data["finished"] = True output = data["outputs"] output_top_logprobs = output.get("top_logprobs") or None @@ -727,13 +735,19 @@ def request_output_to_completion_response( ) if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]: finish_reason = "abort" + if final_res.get("error_msg", None) is not None and "PD Error" in final_res["error_msg"]: + finish_reason = "pd_reschedule" + + return_completion_token_ids = False + if request.return_token_ids or finish_reason == "pd_reschedule": + return_completion_token_ids = True choice_data = CompletionResponseChoice( token_ids=token_ids, index=len(choices), text=output_text, prompt_token_ids=prompt_token_ids if request.return_token_ids else None, - completion_token_ids=completion_token_ids if request.return_token_ids else None, + completion_token_ids=completion_token_ids if return_completion_token_ids else None, completion_tokens=output.get("completion_tokens") if request.return_token_ids else None, prompt_tokens=( prompt_tokens_list[idx // (1 if request.n is None else request.n)] @@ -772,12 +786,24 @@ def request_output_to_completion_response( ) del request + routed_experts = None + if final_res_batch and final_res_batch[-1].get("routing_data") is not None: + import base64 + + import numpy as np + + rd = final_res_batch[-1]["routing_data"] + if not isinstance(rd, np.ndarray): + rd = np.array(rd) + routed_experts = base64.b64encode(rd.tobytes()).decode("utf-8") + return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, + routed_experts=routed_experts, ) async def _call_process_response_dict(self, res, request, stream): diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index f4556a3679f..7435dbce490 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -111,7 +111,7 @@ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) ) ) return ExtractedToolCallInformation( - tools_called=True, + tools_called=len(tool_calls) > 0, tool_calls=tool_calls, ) except Exception: @@ -182,11 +182,13 @@ def extract_tool_calls_streaming( logger.debug("attempting to close tool call, but no tool call") return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") - if diff: - if '"}' not in delta_text: + if diff is not None: + if "}" not in delta_text: + return None + end_loc = delta_text.rindex("}") + diff = delta_text[:end_loc] + if not diff: return None - end_loc = delta_text.rindex('"}') - diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff, @@ -248,15 +250,15 @@ def extract_tool_calls_streaming( prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") - if not cur_arguments and not prev_arguments: + if cur_arguments is None and prev_arguments is None: logger.debug("Skipping text %s - no arguments", delta_text) delta = None - elif not cur_arguments and prev_arguments: + elif cur_arguments is None and prev_arguments is not None: logger.error("should be impossible to have arguments reset " "mid-call. skipping streaming anything.") delta = None - elif cur_arguments and not prev_arguments: + elif cur_arguments is not None and prev_arguments is None: function_name = current_tool_call.get("name") match = re.search( r'\{"name":\s*"' + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', @@ -265,6 +267,19 @@ def extract_tool_calls_streaming( ) if match: cur_arguments_json = match.group(1) + # When tool_call_portion is complete JSON, the regex + # (.*) over-captures the outer closing brace of the + # tool call object. Strip it from both + # cur_arguments_json and delta_text, consistent with + # the both-have-arguments branch handling. + try: + json.loads(tool_call_portion) + if cur_arguments_json.endswith("}"): + cur_arguments_json = cur_arguments_json[:-1] + if delta_text.rstrip().endswith("}"): + delta_text = delta_text.rstrip()[:-1] + except Exception: + pass else: cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) @@ -287,7 +302,7 @@ def extract_tool_calls_streaming( ) self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - elif cur_arguments and prev_arguments: + elif cur_arguments is not None and prev_arguments is not None: try: json.loads(tool_call_portion) is_complete_json = True diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index baa428b5003..57976b0f5b2 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -341,9 +341,10 @@ async def close(self): def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + _is_multi_server = os.environ.get("FD_ENABLE_MULTI_API_SERVER") == "1" parser.add_argument("--port", default=8000, type=int, help="port to the http server") parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") - parser.add_argument("--workers", default=1, type=int, help="number of workers") + parser.add_argument("--workers", default=1 if _is_multi_server else 4, type=int, help="number of workers") parser.add_argument("--metrics-port", default=None, type=int, help="port for metrics server") parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") parser.add_argument( @@ -352,7 +353,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, help="max waiting time for connection, if set value -1 means no waiting time limit", ) - parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") + parser.add_argument( + "--max-concurrency", default=512 if _is_multi_server else 2048, type=int, help="max concurrency" + ) parser.add_argument( "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. " diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0c7ac3e22b1..f030478f1b3 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -70,7 +70,9 @@ def _validate_split_kv_size(value: int) -> int: # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" # and "MLA_ATTN" can be set currently. "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), - # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. + # enable decode attention + "USE_DECODE_UNIFIED_ATTENTION": lambda: bool(int(os.getenv("USE_DECODE_UNIFIED_ATTENTION", "0"))), + # Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), # Set moe backend."cutlass","marlin", "triton", "flashinfer-cutlass", "flashinfer-cutedsl" and "flashinfer-trtllm" can be set currently. "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), @@ -145,6 +147,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_ZMQ_CONTROL_CMD_SERVER_PORTS": lambda: os.getenv("FD_ZMQ_CONTROL_CMD_SERVER_PORTS", "8202"), # Whether to enable the decode caches requests for preallocating resource "FD_ENABLE_CACHE_TASK": lambda: os.getenv("FD_ENABLE_CACHE_TASK", "0"), + # Batched token timeout in EP + "FD_EP_BATCHED_TOKEN_TIMEOUT": lambda: float(os.getenv("FD_EP_BATCHED_TOKEN_TIMEOUT", "0.1")), # Max pre-fetch requests number in PD "FD_EP_MAX_PREFETCH_TASK_NUM": lambda: int(os.getenv("FD_EP_MAX_PREFETCH_TASK_NUM", "8")), # Enable or disable model caching. @@ -152,6 +156,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_ENABLE_MODEL_LOAD_CACHE": lambda: bool(int(os.getenv("FD_ENABLE_MODEL_LOAD_CACHE", "0"))), # Whether to clear cpu cache when clearing model weights. "FD_ENABLE_SWAP_SPACE_CLEARING": lambda: int(os.getenv("FD_ENABLE_SWAP_SPACE_CLEARING", "0")), + # AS-only flush mode: AttentionStore only reports cache index without storing actual data. + "FD_AS_ONLY_FLUSH": lambda: bool(int(os.getenv("FD_AS_ONLY_FLUSH", "0"))), # enable return text, used when FD_ENABLE_INTERNAL_ADAPTER=1 "FD_ENABLE_RETURN_TEXT": lambda: bool(int(os.getenv("FD_ENABLE_RETURN_TEXT", "0"))), # Used to truncate the string inserted during thinking when reasoning in a model. ( for ernie-45-vl, \n\n\n for ernie-x1) @@ -174,8 +180,8 @@ def _validate_split_kv_size(value: int) -> int: "PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES": lambda: int( os.getenv("PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", "1") ), - "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), - "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), + "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), + "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "1")), "FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")), "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), "FD_ENABLE_ASYNC_LLM": lambda: int(os.getenv("FD_ENABLE_ASYNC_LLM", "0")), @@ -187,6 +193,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), # "Enable FP8 calibration on HPU" "FD_HPU_MEASUREMENT_MODE": lambda: os.getenv("FD_HPU_MEASUREMENT_MODE", "0"), + # Number of worker threads for prepare requests in prefill instance + "FD_PREFILL_PREPARE_REQ_THREAD_NUM": lambda: int(os.getenv("FD_PREFILL_PREPARE_REQ_THREAD_NUM", "3")), "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), "FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE": lambda: int( os.getenv("FD_ENABLE_REQUEST_DISCONNECT_STOP_INFERENCE", "1") @@ -210,16 +218,12 @@ def _validate_split_kv_size(value: int) -> int: "FD_XPU_MOE_FFN_QUANT_TYPE_MAP": lambda: os.getenv("FD_XPU_MOE_FFN_QUANT_TYPE_MAP", ""), # Whether to enable low latency in mixed scenario "FD_XPU_ENABLE_MIXED_EP_MODE": lambda: bool(int(os.getenv("FD_XPU_ENABLE_MIXED_EP_MODE", "0"))), - # Whether to use phi FP8 quantization,if 1,use paddle default. - "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), - # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, - # intended for training alignment. Defaults to 0 (disabled). - "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), - # Whether to use phi MOE permute,if 1,use paddle op. - "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), - # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function - "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), # Reserve output blocks for decoding requests when schedule new prefill requests + "FD_INIT_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7")), + "FD_MIN_NEW_TOKEN_RATIO": lambda: float(os.getenv("FD_MIN_NEW_TOKEN_RATIO", "0.1")), + "FD_NEW_TOKEN_RATIO_DECAY": lambda: float(os.getenv("FD_NEW_TOKEN_RATIO_DECAY", "0.001")), + "FD_CLIP_MAX_NEW_TOKENS": lambda: int(os.getenv("FD_CLIP_MAX_NEW_TOKENS", "4096")), + # Legacy reserve block env vars (kept for backwards compatibility, no longer used) "FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "16") ), @@ -229,6 +233,9 @@ def _validate_split_kv_size(value: int) -> int: "FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL": lambda: int( os.getenv("FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL", "0") ), + # When True, use per-request new_token_ratio to estimate reserved blocks (SGLang-style). + # When False, fall back to the legacy fixed-block reservation strategy. + "FD_USE_NEW_TOKEN_RATIO_RESERVE": lambda: bool(int(os.getenv("FD_USE_NEW_TOKEN_RATIO_RESERVE", "1"))), # Timeout for worker process health check in seconds "FD_WORKER_ALIVE_TIMEOUT": lambda: int(os.getenv("FD_WORKER_ALIVE_TIMEOUT", "30")), # File path for file storage backend @@ -247,6 +254,10 @@ def _validate_split_kv_size(value: int) -> int: "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), + # PD disaggregation cache transfer mode: + # 0 (default): Direct transfer mode, P writes cache to D's GPU via RDMA/IPC + # 1: Storage pool mode, P writes cache to global storage pool, D reads from storage pool + "FD_PD_TRANSFER_VIA_STORAGE": lambda: int(os.getenv("FD_PD_TRANSFER_VIA_STORAGE", "0")), # Whether to enable KV cache lock, enforcing mutual exclusion between # PrefixCacheManager and Worker when accessing GPU KV cache. # Under certain DP+EP configurations, concurrent access (even read-only) @@ -266,6 +277,26 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Whether to use GDR CheckpointTransfer for dynamic weight updates. + "FD_USE_GDR_CHECKPOINT_TRANSFER": lambda: bool(int(os.getenv("FD_USE_GDR_CHECKPOINT_TRANSFER", "0"))), + # train-infer consistency, used in RL + # Whether to align RoPE and moe gate precision with training + "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")), + # Whether to use phi FP8 quantization,if 1,use paddle default. + "FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))), + # Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc, + # intended for training alignment. Defaults to 0 (disabled). + "FD_USE_PHI_MOE_TOPK": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_TOPK", "0"))), + # Whether to use phi MOE permute,if 1,use paddle op. + "FD_USE_PHI_MOE_PERMUTE": lambda: bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "0"))), + # Whether to use phi rms_norm,if 1,use paddle op. + "FD_USE_PHI_RMSNORM": lambda: bool(int(os.getenv("FD_USE_PHI_RMSNORM", "0"))), + # Control class SiluAndMul to use swiglu or fusid_bias_act operator in the forward_cuda function + "FD_SiluAndMul_USE_PHI_SWIGLU": lambda: bool(int(os.getenv("FD_SiluAndMul_USE_PHI_SWIGLU", "0"))), + # Whether to enable FP8 quantization with pow2scale. + "FD_FP8_QUANT_WITH_POW2SCALE": lambda: bool(int(os.getenv("FD_FP8_QUANT_WITH_POW2SCALE", "0"))), + # Whether to enable top_p=1.0 optimization. + "FD_ENABLE_TOP_P_ONE_OPT": lambda: bool(int(os.getenv("FD_ENABLE_TOP_P_ONE_OPT", "1"))), } diff --git a/fastdeploy/eplb/async_expert_loader.py b/fastdeploy/eplb/async_expert_loader.py index 0cf9eb0453e..2832a7f635f 100644 --- a/fastdeploy/eplb/async_expert_loader.py +++ b/fastdeploy/eplb/async_expert_loader.py @@ -24,8 +24,24 @@ import paddle try: - from cuda import cudart -except ImportError: + import cuda as _cuda_pkg + + _cuda_ver = getattr(_cuda_pkg, "__version__", None) + if _cuda_ver is None: + # cuda-python >= 13.x does not expose a top-level __version__; + # detect the version via the cuda-bindings package. + import importlib.metadata as _meta + + _cuda_ver = _meta.version("cuda-bindings") + _cuda_major = int(_cuda_ver.split(".")[0]) + if _cuda_major >= 13: + from cuda.bindings import runtime as cudart + else: + from cuda import cudart +except Exception as _e: + import warnings + + warnings.warn(f"cuda-python import failed, async_expert_loader will be unavailable: {_e}") cudart = None from fastdeploy.config import EPLBConfig @@ -98,6 +114,7 @@ def create_mmap(model_name: List, ep_rank: int, ep_size: int, shm_uuid: str, epl raise ImportError( "cuda-python not installed. Install the version matching your CUDA toolkit:\n" " CUDA 12.x → pip install cuda-python==12.*\n" + " CUDA 13.x → pip install cuda-python cuda-bindings\n" ) # Register memory with CUDA diff --git a/fastdeploy/input/base_processor.py b/fastdeploy/input/base_processor.py index 357339be766..cdc01067bfd 100644 --- a/fastdeploy/input/base_processor.py +++ b/fastdeploy/input/base_processor.py @@ -236,6 +236,17 @@ def process_response_dict(self, response_dict, **kwargs): ``stream`` is read from ``kwargs`` (default: True). """ + # Error responses (e.g., preemption) have outputs=None or error_code!=200. + # Skip token decoding and return as-is to let upstream error handling take over. + if isinstance(response_dict, dict): + outputs = response_dict.get("outputs") + error_code = response_dict.get("error_code", 200) + else: + outputs = getattr(response_dict, "outputs", None) + error_code = getattr(response_dict, "error_code", 200) + if outputs is None or error_code != 200: + return response_dict + stream = kwargs.get("stream", True) if stream: return self.process_response_dict_streaming(response_dict, **kwargs) @@ -412,6 +423,9 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): if len(request["prompt_token_ids"]) == 0: raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs") + if request.get("completion_token_ids"): + request["prompt_token_ids"].extend(request["completion_token_ids"]) + # truncate prompts that exceed the length limit if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1] diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py index 755f0612def..6467f6a89ac 100644 --- a/fastdeploy/input/preprocess.py +++ b/fastdeploy/input/preprocess.py @@ -48,6 +48,7 @@ def __init__( mm_processor_kwargs: Optional[Dict[str, Any]] = None, tool_parser: str = None, enable_processor_cache: bool = False, + enable_mm_runtime: Optional[bool] = None, ) -> None: self.model_config = model_config self.model_name_or_path = self.model_config.model @@ -56,6 +57,7 @@ def __init__( self.mm_processor_kwargs = mm_processor_kwargs self.tool_parser = tool_parser self.enable_processor_cache = enable_processor_cache + self.enable_mm_runtime = self.model_config.enable_mm if enable_mm_runtime is None else enable_mm_runtime def create_processor(self): reasoning_parser_obj = None @@ -77,10 +79,11 @@ def create_processor(self): reasoning_parser_obj=reasoning_parser_obj, tool_parser_obj=tool_parser_obj, mm_processor_kwargs=self.mm_processor_kwargs, + enable_mm_runtime=self.enable_mm_runtime, ) except Exception as e: logger.info(f"Plugin input processor not available ({e}), using built-in processor") - if not self.model_config.enable_mm: + if not self.enable_mm_runtime: from fastdeploy.input.text_processor import TextProcessor tokenizer_type = "ernie4_5" if ErnieArchitectures.contains_ernie_arch(architecture) else "auto" diff --git a/fastdeploy/input/utils.py b/fastdeploy/input/utils.py index 19a86f31574..74e1cf2c6fb 100644 --- a/fastdeploy/input/utils.py +++ b/fastdeploy/input/utils.py @@ -82,6 +82,7 @@ def process_stop_token_ids( update_stop_seq_fn: Callable[[List[str]], Tuple[List[List[int]], List[int]]], ) -> None: stop_token_ids_final = [] + stop_seqs_len_final = [] if request.get("stop_token_ids") is not None: stop_token_ids = request.get("stop_token_ids") @@ -89,17 +90,19 @@ def process_stop_token_ids( if isinstance(stop_token_ids[0], int): # List[int] -> List[List[int]] stop_token_ids_final.extend([[t] for t in stop_token_ids]) + stop_seqs_len_final.extend([1] * len(stop_token_ids)) elif isinstance(stop_token_ids[0], list): # Already List[List[int]] stop_token_ids_final.extend(stop_token_ids) + stop_seqs_len_final.extend([len(seq) for seq in stop_token_ids]) stop_sequences = request.get("stop", []) if stop_sequences: - stop_seqs, _ = update_stop_seq_fn(stop_sequences) + stop_seqs, stop_seqs_actual_lens = update_stop_seq_fn(stop_sequences) stop_token_ids_final.extend(stop_seqs) + stop_seqs_len_final.extend(stop_seqs_actual_lens) # Update request if stop_token_ids_final: - stop_seqs_len = [len(seq) for seq in stop_token_ids_final] request["stop_token_ids"] = stop_token_ids_final - request["stop_seqs_len"] = stop_seqs_len + request["stop_seqs_len"] = stop_seqs_len_final diff --git a/fastdeploy/inter_communicator/engine_cache_queue.py b/fastdeploy/inter_communicator/engine_cache_queue.py index 535dc1dc4c3..97a08cd4e88 100644 --- a/fastdeploy/inter_communicator/engine_cache_queue.py +++ b/fastdeploy/inter_communicator/engine_cache_queue.py @@ -24,7 +24,7 @@ Value, ValueProxy, ) -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union from fastdeploy.utils import get_logger @@ -39,7 +39,7 @@ class EngineCacheQueue: def __init__( self, - address: Tuple[str, int] = ("127.0.0.1", 56666), + address: Union[Tuple[str, int], str] = ("127.0.0.1", 56666), authkey: bytes = b"cache_queue_service", is_server: bool = False, num_client: int = 1, # tensor parallel size @@ -62,7 +62,7 @@ def __init__( TODO(liyonghua): Remove multi-DP initialization. Each DP will have its own cache queue. """ - self.address: Tuple[str, int] = address + self.address: Union[Tuple[str, int], str] = address self.authkey: bytes = authkey self.is_server: bool = is_server self.num_client: int = num_client @@ -210,8 +210,10 @@ class QueueManager(BaseManager): QueueManager.register("get_swap_storage_to_gpu_barrier") QueueManager.register("get_swap_to_storage_barrier") + logger.info(f"Try to connect QueueManager, address: {self.address}") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() + logger.info(f"Connected to QueueManager, address: {self.address}") # Get proxy objects for shared resources self.transfer_task_queue = self.manager.get_transfer_task_queue(self.local_data_parallel_id) @@ -246,7 +248,7 @@ def get_server_port(self) -> int: Returns the actual port that the server instance is listening on. Calling this method only makes sense on instances where is_server=True. """ - if not self.is_server: + if not self.is_server or isinstance(self.address, str): raise RuntimeError("Only the server instance can provide the port.") return self.address[1] diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index b64fcacda33..2cb1246aad3 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -92,7 +92,6 @@ class QueueManager(BaseManager): Value("i", 0) for _ in range(self.local_data_parallel_size) ] self.finished_req_list = [list() for _ in range(self.local_data_parallel_size)] - self.finished_add_cache_task_list = [list() for _ in range(self.local_data_parallel_size)] self.cache_infos_init: List[List[Any]] = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_list = [list() for _ in range(self.local_data_parallel_size)] self.connect_rdma_tasks_response_list = [list() for _ in range(self.local_data_parallel_size)] @@ -110,9 +109,6 @@ class QueueManager(BaseManager): self.connect_task_response_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # connect rdma task response - self.finish_add_cache_task_lock_init: List[threading.Lock] = [ - threading.Lock() for _ in range(self.local_data_parallel_size) - ] # finish add cache task self.finish_send_cache_lock_init: List[threading.Lock] = [ threading.Lock() for _ in range(self.local_data_parallel_size) ] # finish send cache @@ -124,18 +120,12 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] - self.client_get_finished_add_cache_task_flag_init: List[List[int]] = [ - [0] * self.num_client for _ in range(self.local_data_parallel_size) - ] self.client_get_finish_send_cache_flag_init: List[List[int]] = [ [0] * self.num_client for _ in range(self.local_data_parallel_size) ] self.can_put_next_connect_task_response_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] - self.can_put_next_add_task_finished_flag_init: List[Value] = [ - Value("i", 1) for _ in range(self.local_data_parallel_size) - ] self.can_put_next_send_cache_finished_flag_init: List[Value] = [ Value("i", 1) for _ in range(self.local_data_parallel_size) ] @@ -147,9 +137,6 @@ class QueueManager(BaseManager): self.get_connect_task_response_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] - self.finish_add_cache_task_barrier = [ - threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) - ] self.begin_send_cache_barrier = [ threading.Barrier(self.num_client) for _ in range(self.local_data_parallel_size) ] @@ -188,11 +175,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.client_get_connect_task_response_flag_init[idx], proxytype=ListProxy, ) - QueueManager.register( - "get_client_get_finished_add_cache_task_flag_init", - callable=lambda idx: self.client_get_finished_add_cache_task_flag_init[idx], - proxytype=ListProxy, - ) QueueManager.register( "get_client_get_finish_send_cache_flag_init", callable=lambda idx: self.client_get_finish_send_cache_flag_init[idx], @@ -218,11 +200,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.can_put_next_connect_task_response_flag_init[idx], proxytype=ValueProxy, ) - QueueManager.register( - "get_can_put_next_add_task_finished_flag", - callable=lambda idx: self.can_put_next_add_task_finished_flag_init[idx], - proxytype=ValueProxy, - ) QueueManager.register( "get_can_put_next_send_cache_finished_flag", callable=lambda idx: self.can_put_next_send_cache_finished_flag_init[idx], @@ -239,11 +216,6 @@ class QueueManager(BaseManager): callable=lambda idx: self.connect_task_response_lock_init[idx], proxytype=AcquirerProxy, ) - QueueManager.register( - "get_finish_add_cache_task_lock", - callable=lambda idx: self.finish_add_cache_task_lock_init[idx], - proxytype=AcquirerProxy, - ) QueueManager.register( "get_finish_send_cache_lock", callable=lambda idx: self.finish_send_cache_lock_init[idx], @@ -268,12 +240,6 @@ class QueueManager(BaseManager): "get_finish_request_queue", callable=lambda idx: self.finished_req_list[idx], proxytype=ListProxy ) - QueueManager.register( - "get_finish_add_cache_task_queue", - callable=lambda idx: self.finished_add_cache_task_list[idx], - proxytype=ListProxy, - ) - QueueManager.register( "get_cache_infos", callable=lambda idx: self.cache_infos_init[idx], @@ -321,12 +287,6 @@ class QueueManager(BaseManager): "get_cache_info_barrier", callable=lambda idx: self.get_cache_info_barrier[idx], ) - - QueueManager.register( - "get_finish_add_cache_task_barrier", - callable=lambda idx: self.finish_add_cache_task_barrier[idx], - ) - QueueManager.register( "get_worker_process_tp_barrier", callable=lambda idx: self.worker_process_tp_barrier[idx], @@ -351,13 +311,11 @@ class QueueManager(BaseManager): QueueManager.register("get_exist_tasks_inter_signal") QueueManager.register("get_connected_client_counter") QueueManager.register("get_finish_request_queue") - QueueManager.register("get_finish_add_cache_task_queue") QueueManager.register("get_cache_infos") QueueManager.register("get_client_read_info_flag") QueueManager.register("get_lock_info") QueueManager.register("get_disaggregate_requests") QueueManager.register("get_finish_request_barrier") - QueueManager.register("get_finish_add_cache_task_barrier") QueueManager.register("get_connect_task_barrier") QueueManager.register("get_connect_task_response_barrier") QueueManager.register("get_finish_send_cache_barrier") @@ -366,16 +324,13 @@ class QueueManager(BaseManager): QueueManager.register("get_connect_rdma_tasks") QueueManager.register("get_client_get_connect_task_flag") QueueManager.register("get_client_get_connect_task_response_flag") - QueueManager.register("get_client_get_finished_add_cache_task_flag_init") QueueManager.register("get_client_get_finish_send_cache_flag_init") QueueManager.register("get_connect_rdma_tasks_responses") QueueManager.register("get_connect_task_lock") QueueManager.register("get_connect_task_response_lock") - QueueManager.register("get_finish_add_cache_task_lock") QueueManager.register("get_finish_send_cache_lock") QueueManager.register("get_worker_process_tp_barrier") QueueManager.register("get_can_put_next_connect_task_response_flag") - QueueManager.register("get_can_put_next_add_task_finished_flag") QueueManager.register("get_can_put_next_send_cache_finished_flag") self.manager = QueueManager(address=self.address, authkey=self.authkey) self._connect_with_retry() @@ -398,9 +353,6 @@ class QueueManager(BaseManager): # p/d 分离获取 self.disaggregate_requests = self.manager.get_disaggregate_requests(self.local_data_parallel_id) self.finish_request_barrier = self.manager.get_finish_request_barrier(self.local_data_parallel_id) - self.finish_add_cache_task_barrier = self.manager.get_finish_add_cache_task_barrier( - self.local_data_parallel_id - ) self.connect_task_barrier = self.manager.get_connect_task_barrier(self.local_data_parallel_id) self.connect_task_response_barrier = self.manager.get_connect_task_response_barrier( self.local_data_parallel_id @@ -410,9 +362,6 @@ class QueueManager(BaseManager): self.begin_send_cache_barrier = self.manager.get_begin_send_cache_barrier(self.local_data_parallel_id) self.worker_process_tp_barrier = self.manager.get_worker_process_tp_barrier(self.local_data_parallel_id) self.finished_send_cache_list = self.manager.get_finish_request_queue(self.local_data_parallel_id) - self.finished_add_cache_task_list = self.manager.get_finish_add_cache_task_queue( - self.local_data_parallel_id - ) # p/d互联 self.connect_rdma_tasks = self.manager.get_connect_rdma_tasks(self.local_data_parallel_id) self.client_get_connect_task_flag = self.manager.get_client_get_connect_task_flag( @@ -421,9 +370,6 @@ class QueueManager(BaseManager): self.client_get_connect_task_response_flag = self.manager.get_client_get_connect_task_response_flag( self.local_data_parallel_id ) - self.client_get_finished_add_cache_task_flag = ( - self.manager.get_client_get_finished_add_cache_task_flag_init(self.local_data_parallel_id) - ) self.client_get_finish_send_cache_flag = self.manager.get_client_get_finish_send_cache_flag_init( self.local_data_parallel_id ) @@ -433,12 +379,8 @@ class QueueManager(BaseManager): ) self.connect_task_lock = self.manager.get_connect_task_lock(self.local_data_parallel_id) self.connect_task_response_lock = self.manager.get_connect_task_response_lock(self.local_data_parallel_id) - self.finish_add_cache_task_lock = self.manager.get_finish_add_cache_task_lock(self.local_data_parallel_id) self.finish_send_cache_lock = self.manager.get_finish_send_cache_lock(self.local_data_parallel_id) - self.can_put_next_add_task_finished_flag = self.manager.get_can_put_next_add_task_finished_flag( - self.local_data_parallel_id - ) self.can_put_next_connect_task_response_flag = self.manager.get_can_put_next_connect_task_response_flag( self.local_data_parallel_id ) @@ -549,7 +491,6 @@ def put_tasks(self, tasks: List[Any]) -> None: self.lock.release() time.sleep(0.001) self.lock.acquire() - if envs.FD_ENABLE_MAX_PREFILL or envs.FD_ENABLE_E2W_TENSOR_CONVERT: # multimodal input numpy -> tensor to_tensor(tasks[0]) @@ -571,7 +512,6 @@ def get_tasks(self) -> Tuple[List[Any], bool]: """ tasks: List[Any] = list() self.lock.acquire() - tasks.extend(self.tasks) self.client_read_flag[self.client_id] = 1 all_client_read: bool = np.sum(self.client_read_flag) == self.num_client @@ -758,54 +698,6 @@ def get_finished_req(self) -> str: self.finish_send_cache_lock.release() return response - def put_finished_add_cache_task_req(self, req_ids) -> None: - """ - Put finished request ID into the queue. - - Args: - req_ids: Request ID to be added to the queue - """ - self.finish_add_cache_task_lock.acquire() - while not self.can_put_next_add_task_finished_flag.get(): - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - self.finished_add_cache_task_list.append(req_ids) - self.client_get_finished_add_cache_task_flag[self.client_id] = 1 - all_client_put: bool = np.sum(self.client_get_finished_add_cache_task_flag) == self.num_client - if all_client_put: - self.can_put_next_add_task_finished_flag.set(0) - self.finish_add_cache_task_lock.release() - return all_client_put - - def get_finished_add_cache_task_req(self) -> str: - """ - Get finished request ID from the queue. - - Returns: - str: Finished request ID - """ - response = [] - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) == 0: - self.finish_add_cache_task_lock.release() - return response - while sum(self.client_get_finished_add_cache_task_flag) < self.num_client: - self.finish_add_cache_task_lock.release() - time.sleep(0.001) - self.finish_add_cache_task_lock.acquire() - if len(self.finished_add_cache_task_list) > 0: - response = self.finished_add_cache_task_list[0] - for tmp_response in self.finished_add_cache_task_list: - assert ( - tmp_response == response - ), f"Inconsistent responses across workers: expected {response}, got {tmp_response}" - self.finished_add_cache_task_list[:] = list() - self.client_get_finished_add_cache_task_flag[:] = [0] * self.num_client - self.can_put_next_add_task_finished_flag.set(1) - self.finish_add_cache_task_lock.release() - return response - def disaggregate_queue_empty(self): """ Check if the disaggregated task queue is empty. @@ -850,3 +742,13 @@ def cleanup(self): """ if self.manager is not None and self.is_server: self.manager.shutdown() + + def is_broken(self): + try: + self.manager.connect() + return False + except (ConnectionRefusedError, ConnectionResetError, BrokenPipeError, EOFError, OSError): + llm_logger.error("Failed to connect to engine worker queue") + return True + except Exception: + return False diff --git a/fastdeploy/metrics/benchmark_metrics_logger.py b/fastdeploy/metrics/benchmark_metrics_logger.py new file mode 100644 index 00000000000..7e381fb3cc6 --- /dev/null +++ b/fastdeploy/metrics/benchmark_metrics_logger.py @@ -0,0 +1,222 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import json +import os +import threading +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +import numpy as np + +from fastdeploy.config import BenchmarkMetricsConfig + + +@dataclass(slots=True) +class CompletedRequestRecord: + """Raw timing data collected when a request completes.""" + + request_id: str + completion_time: float + arrival_time: float + inference_start_time: float + first_token_time: float + last_token_time: float + input_len: int + output_len: int + num_cached_tokens: int = 0 + itl_samples: list = field(default_factory=list) + + +class BenchmarkMetricsLogger: + """ + In-process performance monitoring that produces metrics aligned with + benchmark_serving.py. Uses a lock-free deque for data collection and + a background daemon thread for stats computation and file I/O. + """ + + def __init__(self, config: BenchmarkMetricsConfig, log_dir: str, dp_rank: int = 0): + self.config = config + self.enabled = config.enable + self.dp_rank = dp_rank + + if config.window_mode == "sliding" and config.window_size > 0: + self._window: deque = deque(maxlen=config.window_size) + else: + self._window: deque = deque() + + self._pending: deque = deque() + self._condition = threading.Condition() + self._stop_event = threading.Event() + + os.makedirs(log_dir, exist_ok=True) + self._file_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + self._file = open(self._file_path, "a", encoding="utf-8") + + self._thread = threading.Thread( + target=self._writer_loop, + daemon=True, + name=f"BenchmarkMetricsLogger-dp{dp_rank}", + ) + self._thread.start() + + def on_request_completed(self, record: CompletedRequestRecord) -> None: + """Called from token processor on request completion. Lock-free append.""" + self._pending.append(record) + with self._condition: + self._condition.notify() + + def _writer_loop(self) -> None: + """Background thread: wait for new records, compute stats, write JSONL.""" + while not self._stop_event.is_set(): + with self._condition: + self._condition.wait(timeout=1.0) + self._process_pending() + + def _process_pending(self) -> None: + """Process all pending records, write one JSONL line per record.""" + while True: + try: + record = self._pending.popleft() + except IndexError: + break + self._window.append(record) + stats = self._compute_rolling_stats() + line = json.dumps(stats, ensure_ascii=False) + self._file.write(line + "\n") + # Tumbling window: clear after reaching window_size + if ( + self.config.window_mode == "tumbling" + and self.config.window_size > 0 + and len(self._window) >= self.config.window_size + ): + self._window.clear() + self._file.flush() + + def _compute_rolling_stats(self) -> dict: + """Compute aggregate statistics over the current window.""" + records = list(self._window) + n = len(records) + if n == 0: + return {"timestamp": datetime.now().isoformat(), "completed": 0} + + selected = self.config.selected_metrics + percentile_values = self.config.percentile_values + + ttfts = [] + s_ttfts = [] + tpots = [] + all_itls = [] + e2els = [] + s_e2els = [] + decode_speeds = [] + input_lens = [] + s_input_lens = [] + output_lens = [] + + for r in records: + if r.first_token_time and r.arrival_time: + ttfts.append((r.first_token_time - r.arrival_time) * 1000) + if r.first_token_time and r.inference_start_time: + s_ttfts.append((r.first_token_time - r.inference_start_time) * 1000) + if r.output_len > 1 and r.first_token_time and r.arrival_time: + e2el_s = r.last_token_time - r.arrival_time + ttft_s = r.first_token_time - r.arrival_time + tpots.append(((e2el_s - ttft_s) / (r.output_len - 1)) * 1000) + if r.itl_samples: + all_itls.extend([x * 1000 for x in r.itl_samples]) + if r.last_token_time and r.arrival_time: + e2els.append((r.last_token_time - r.arrival_time) * 1000) + if r.last_token_time and r.inference_start_time: + s_e2els.append((r.last_token_time - r.inference_start_time) * 1000) + if r.output_len > 1 and r.first_token_time and r.last_token_time: + decode_time = r.last_token_time - r.first_token_time + if decode_time > 0: + decode_speeds.append((r.output_len - 1) / decode_time) + input_lens.append(r.num_cached_tokens) + s_input_lens.append(r.input_len) + output_lens.append(r.output_len) + + # Throughput: based on window time span + total_input = sum(s_input_lens) + total_output = sum(output_lens) + if n >= 2: + duration = records[-1].completion_time - records[0].arrival_time + else: + duration = 0.0 + + result: dict[str, Any] = { + "timestamp": datetime.now().isoformat(), + "window_size": self.config.window_size, + "window_mode": self.config.window_mode, + "completed": n, + "total_input_tokens": total_input, + "total_output_tokens": total_output, + } + + if duration > 0: + result["request_throughput"] = round(n / duration, 2) + result["output_throughput"] = round(total_output / duration, 2) + result["total_throughput"] = round((total_input + total_output) / duration, 2) + + if "ttft" in selected: + result["ttft_ms"] = self._stats(ttfts, percentile_values) + if "s_ttft" in selected: + result["s_ttft_ms"] = self._stats(s_ttfts, percentile_values) + if "tpot" in selected: + result["tpot_ms"] = self._stats(tpots, percentile_values) + if "s_itl" in selected: + result["s_itl_ms"] = self._stats(all_itls, percentile_values) + if "e2el" in selected: + result["e2el_ms"] = self._stats(e2els, percentile_values) + if "s_e2el" in selected: + result["s_e2el_ms"] = self._stats(s_e2els, percentile_values) + if "s_decode" in selected: + result["s_decode"] = self._stats(decode_speeds, percentile_values) + if "input_len" in selected: + result["input_len"] = self._stats(input_lens, percentile_values) + if "s_input_len" in selected: + result["s_input_len"] = self._stats(s_input_lens, percentile_values) + if "output_len" in selected: + result["output_len"] = self._stats(output_lens, percentile_values) + + return result + + @staticmethod + def _stats(values: list, percentiles: list[float]) -> dict: + """Compute mean/median/percentiles for a list of values.""" + if not values: + return {} + arr = np.array(values) + result = { + "mean": round(float(np.mean(arr)), 2), + "median": round(float(np.median(arr)), 2), + } + for p in percentiles: + key = f"p{int(p)}" if int(p) == p else f"p{p}" + result[key] = round(float(np.percentile(arr, p)), 2) + return result + + def shutdown(self) -> None: + """Stop the writer thread and close the file.""" + self._stop_event.set() + with self._condition: + self._condition.notify() + self._thread.join(timeout=5) + self._process_pending() + self._file.close() diff --git a/fastdeploy/metrics/metrics.py b/fastdeploy/metrics/metrics.py index 42fd0231bf0..0daa36ad58a 100644 --- a/fastdeploy/metrics/metrics.py +++ b/fastdeploy/metrics/metrics.py @@ -136,6 +136,7 @@ class MetricsManager: num_requests_running: "Gauge" num_requests_waiting: "Gauge" + num_requests_queuing: "Gauge" time_to_first_token: "Histogram" time_per_output_token: "Histogram" request_inference_time: "Histogram" @@ -153,7 +154,6 @@ class MetricsManager: spec_decode_num_emitted_tokens_total: "Gauge" spec_decode_draft_single_head_acceptance_rate: "list[Gauge]" - # for YIYAN Adapter prefix_cache_token_num: "Counter" prefix_gpu_cache_token_num: "Counter" prefix_cpu_cache_token_num: "Counter" @@ -192,6 +192,11 @@ class MetricsManager: request_prompt_tokens: "Histogram" request_token_ratio: "Histogram" + # for pd + decode_preallocated_req_num: "Gauge" + reschedule_req_num: "Counter" + failed_recv_first_token_req_num: "Counter" + # 定义所有指标配置 # gauge指标在多进程中,会有pid隔离,需要特殊处理,因此手动定义出来 @@ -205,7 +210,13 @@ class MetricsManager: "num_requests_waiting": { "type": Gauge, "name": "fastdeploy:num_requests_waiting", - "description": "Number of requests currently waiting", + "description": "Number of requests currently waiting in resource manager", + "kwargs": {}, + }, + "num_requests_queuing": { + "type": Gauge, + "name": "fastdeploy:num_requests_queuing", + "description": "Number of requests currently queuing in local scheduler", "kwargs": {}, }, "gpu_cache_usage_perc": { @@ -298,6 +309,12 @@ class MetricsManager: "description": "Token-level GPU prefix cache hit rate", "kwargs": {}, }, + "decode_preallocated_req_num": { + "type": Gauge, + "name": "fastdeploy:decode_preallocated_req_num", + "description": "Number of preallocated requests in decode instance", + "kwargs": {}, + }, } METRICS = { @@ -459,6 +476,18 @@ class MetricsManager: ], }, }, + "reschedule_req_num": { + "type": Counter, + "name": "fastdeploy:reschedule_req_num", + "description": "Total number of reschedule requests", + "kwargs": {}, + }, + "failed_recv_first_token_req_num": { + "type": Counter, + "name": "fastdeploy:failed_recv_first_token_req_num", + "description": "Total number of failed requests to receive the first token in decode", + "kwargs": {}, + }, } SPECULATIVE_METRICS = {} diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index d6a448a2693..effb8108422 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -146,8 +146,8 @@ class ForwardMeta: caches: Optional[list[paddle.Tensor]] = None # Flag of profile run is_dummy_or_profile_run: bool = False - # Routing Replay table buffer - routing_replay_table: Optional[paddle.Tensor] = None + # GPU transient routing buffer [max_num_batched_tokens, num_moe_layers, top_k] + device_routing_buffer: Optional[paddle.Tensor] = None # chunked MoE related moe_num_chunk: int = 1 @@ -159,6 +159,8 @@ class ForwardMeta: exist_prefill: bool = False position_ids: Optional[paddle.Tensor] = None + # for kvcache slot + slot_mapping: Optional[paddle.Tensor] = None real_bsz: int = 0 diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 526f55c2369..c04f137d10d 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -29,6 +29,7 @@ capture_custom_allreduce, custom_ar_clear_ipc_handles, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import get_logger logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log") @@ -123,9 +124,46 @@ def __init__( self.max_num_seqs = fd_config.scheduler_config.max_num_seqs self.real_bsz_to_captured_size = fd_config.graph_opt_config.real_bsz_to_captured_size - def run_static_model(self, entry: ConcreteSizeEntry, **kwargs): + # Expected decode capture sequence (descending), consistent with capture_model() iteration order. + # Used to validate that captures happen in the correct order. + self._decode_expected_sequence: list[int] = sorted(self.cudagraph_capture_sizes, reverse=True) + # Points to the next expected position in _decode_expected_sequence. + self._decode_capture_index: int = 0 + + def _validate_decode_capture_order(self, shape: int) -> None: + """Validate that decode CUDA graph captures happen in expected descending order. + + Raises RuntimeError immediately if the actual capture order deviates from + the order defined by cudagraph_capture_sizes (sorted descending). + """ + if current_platform.is_xpu(): + return + + if self._decode_capture_index >= len(self._decode_expected_sequence): + raise RuntimeError( + f"[CUDA GRAPH][ID:{id(self)}] Unexpected CUDA graph capture: shape={shape}. " + f"All {len(self._decode_expected_sequence)} expected captures have already completed. " + f"Expected sequence: {self._decode_expected_sequence}" + ) + expected = self._decode_expected_sequence[self._decode_capture_index] + if shape != expected: + raise RuntimeError( + f"[CUDA GRAPH][ID:{id(self)}] CUDA graph capture order mismatch at index " + f"{self._decode_capture_index}: expected shape={expected}, got shape={shape}. " + f"Full expected sequence: {self._decode_expected_sequence}" + ) + logger.debug( + f"[CUDA GRAPH][ID:{id(self)}] Capture order validated: shape={shape} matches " + f"expected sequence at index {self._decode_capture_index} " + f"(sequence: {self._decode_expected_sequence})" + ) + self._decode_capture_index += 1 + + def run_static_model(self, entry: ConcreteSizeEntry, is_decode: bool = False, **kwargs): if not entry.captured: + if is_decode: + self._validate_decode_capture_order(entry.real_shape) # Warmup the model for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 @@ -194,13 +232,14 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor: # - Static full graph mode: Dynamic for prefill/mixed, Static + CUDAGraph for decode # - Dynamic mode: Dynamic + CUDAGraph for decode only if static_cudagraph_for_prefill or static_cudagraph_for_decode: - return self.run_static_model(entry, **kwargs) + return self.run_static_model(entry, is_decode=static_cudagraph_for_decode, **kwargs) # Capture a new cuda graph if entry.cuda_graph is None: assert ( real_shape == padding_real_shape ), f"real_shape:{real_shape} is not equal to padding_real_shape:{padding_real_shape} when capture new graph." + self._validate_decode_capture_order(padding_real_shape) # Warmup the model for n in range(entry.num_finished_warmup, self.warm_up_size): entry.num_finished_warmup += 1 @@ -278,6 +317,8 @@ def clear_graph(self): del self.concrete_size_entries paddle.device.cuda.empty_cache() + self._decode_capture_index = 0 + # Create new entrys self._create_entry_dict() diff --git a/fastdeploy/model_executor/graph_optimization/decorator.py b/fastdeploy/model_executor/graph_optimization/decorator.py index 562164aae1a..05ec79a495c 100644 --- a/fastdeploy/model_executor/graph_optimization/decorator.py +++ b/fastdeploy/model_executor/graph_optimization/decorator.py @@ -92,7 +92,7 @@ def forward(self, **kwargs): def __call__(self, **kwargs): return self.graph_opt_backend(**kwargs) - def clear_grpah_opt_backend(self, fd_config): + def clear_graph_opt_backend(self, fd_config): """ """ # TODO(gongshaotian): Resolve the bug of static graphs not being able to update weights assert ( diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 81eab7cce86..e003b1be089 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -73,6 +73,8 @@ def allocate_launch_related_buffer( num_heads, kv_num_heads, block_size, + head_dim=128, + dtype="bfloat16", ): # Initialize AttentionBackend buffers assert num_heads % kv_num_heads == 0 @@ -107,6 +109,28 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + # Decode attention split ops buffers + if envs.USE_DECODE_UNIFIED_ATTENTION: + min_chunk_size = 512 + max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 if decoder_step_token_num * group_size <= 16 else 32 + q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // q_tile_size + res["decode_block_indices"] = paddle.full( + [max_batch_size * kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + res["decode_num_blocks"] = paddle.full([1], 0, dtype="int32") + res["decode_chunk_size"] = paddle.full([1], 0, dtype="int32") + res["decode_tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], 0, dtype=dtype + ) + res["decode_tmp_m"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + res["decode_tmp_d"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + return res @@ -138,9 +162,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.causal: bool = getattr(fd_config.model_config, "causal", True) @@ -212,7 +234,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -220,7 +246,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -308,7 +334,7 @@ def forward_mixed( # 64 is gpt-oss assert forward_meta.rotary_embs.shape[4] in [128, 32, 64] - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py index 4d6bbcdfb7d..88f28467569 100644 --- a/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/dsa_attention_backend.py @@ -28,6 +28,7 @@ if current_platform.is_cuda(): paddle.enable_compat(scope={"flash_mla"}) +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -54,33 +55,6 @@ def yarn_get_mscale(scale=1, mscale=1): return 0.1 * mscale * math.log(scale) + 1.0 -def compute_slot_mapping( - block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req] - positions: paddle.Tensor, # [num_tokens] 每个token的位置 - batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求 - block_size: int, -) -> paddle.Tensor: - """ - 计算 slot_mapping - - 公式: slot = block_id * block_size + offset_in_block - """ - # 1. 计算每个 token 对应的 block 索引 - block_idx = positions // block_size # [num_tokens] - - # 2. 从 block_tables 中查表获取 block_id - # block_tables[batch_id_per_token, block_idx] - block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens] - - # 3. 计算在 block 内的偏移 - block_offset = positions % block_size # [num_tokens] - - # 4. 计算 slot_mapping - slot_mapping = block_ids * block_size + block_offset - - return slot_mapping.cast(paddle.int64) - - @dataclass class DSAAttentionMetadata(AttentionMetadata): """ @@ -136,7 +110,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method: str = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -270,7 +244,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -278,7 +256,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -331,7 +309,7 @@ def forward_mixed( # speculate_decoder = self.speculative_method is not None # speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -347,18 +325,11 @@ def forward_mixed( k_range = paddle.tensor(200.0) scale = paddle.abs(compressed_kv).max() / k_range - slot_mapping = compute_slot_mapping( - forward_meta.block_tables, - forward_meta.position_ids, - forward_meta.batch_id_per_token, - self.block_size, - ) - dsk_attn_write_cache( compressed_kv, k_pe, latent_cache, - slot_mapping, + forward_meta.slot_mapping, scale.cast(paddle.float32), "fp8_ds_mla", ) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b51dce1449d..d5044dbb544 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -43,6 +43,9 @@ ) from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, get_attn_mask_q, get_block_shape_and_split_kv_block, gqa_rope_write_cache, @@ -84,7 +87,7 @@ def init_flash_attn_version(): sm_version = get_sm_version() if sm_version >= 100: try: - paddle.compat.enable_torch_proxy(scope={"cutlass"}) + paddle.enable_compat(scope={"cutlass"}) from flash_mask.cute.interface import flashmask_attention as fa4 global flashmask_attention_v4 @@ -95,7 +98,7 @@ def init_flash_attn_version(): logger.info(f"The current platform[sm{get_sm_version()}] can't import Flash Attention V4.") if FLASH_ATTN_VERSION is None: - if sm_version >= 89 and any(num >= 89 for num in paddle.version.cuda_archs()): + if sm_version == 90 and 90 in paddle.version.cuda_archs(): FLASH_ATTN_VERSION = 3 logger.info("The current platform supports Flash Attention V3.") else: @@ -267,15 +270,15 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False # Note(ZKK): here must be consistent with append_attn_backend.py self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024)) + self.max_tokens_per_batch: int = self.speculate_max_draft_token_num + 1 if FLASH_ATTN_VERSION is None: init_flash_attn_version() + print(f"num_heads: {self.num_heads}, kv_num_heads: {self.kv_num_heads}") def get_attention_meta(self): """get_attention_meta""" @@ -301,7 +304,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -309,7 +316,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -336,7 +343,7 @@ def forward_mixed( ): metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -416,6 +423,20 @@ def forward_mixed( ) else: forward_meta.attn_mask_q = None + if envs.USE_DECODE_UNIFIED_ATTENTION: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0 @@ -470,73 +491,148 @@ def forward_mixed( head_dim=self.head_dim, )[0].reshape([-1, self.attn_outputsize_tp]) - res_decoder = append_attention( - qkv, - cache_k, - cache_v, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - forward_meta.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, - forward_meta.kv_batch_ids, - forward_meta.kv_tile_ids_per_batch, - forward_meta.kv_num_blocks_x_cpu, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, - forward_meta.rotary_embs, - forward_meta.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - cache_k_scales, - cache_v_scales, - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - forward_meta.attn_mask_offsets, - metadata.kv_signal_data_list[layer.layer_id], - q_norm_weight, - k_norm_weight, - getattr(layer, "sinks", None), - getattr(layer, "rms_norm_eps", 1e-6), - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - self.encoder_block_shape_q, - self.decoder_block_shape_q, - self.max_partition_size, - self.max_seq_len, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - ) - - if use_fa_do_prefill: - merge_prefill_decode_output( - res_encoder, - res_decoder, + if envs.USE_DECODE_UNIFIED_ATTENTION: + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - self.num_heads, - self.head_dim, + forward_meta.block_tables, + forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + layer.qkv_bias, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + if use_fa_do_prefill: + res_decoder = res_encoder + else: + res_decoder = paddle.empty( + [qkv.shape[0], self.num_heads * self.head_dim], + dtype=qkv.dtype, + ) + decode_unified_attention( + qkv_out, + cache_k, + cache_v, + forward_meta.decode_tmp_workspace, + forward_meta.decode_tmp_m, + forward_meta.decode_tmp_d, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + res_decoder, + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), self.speculate_max_draft_token_num + 1, + self.causal, ) - return res_encoder - else: return res_decoder + else: + res_decoder = append_attention( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill else forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + forward_meta.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + forward_meta.attn_mask_offsets, + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "sinks", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.max_partition_size, + self.max_seq_len, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + + if use_fa_do_prefill: + merge_prefill_decode_output( + res_encoder, + res_decoder, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_heads, + self.head_dim, + self.speculate_max_draft_token_num + 1, + ) + return res_encoder + else: + return res_decoder diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 35d27504ab5..4c663c1a702 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -22,6 +22,7 @@ import paddle +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -121,9 +122,7 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768")) @@ -148,7 +147,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # metadata only save pd_disaggregation info. metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -156,7 +159,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -199,7 +202,7 @@ def forward_mixed( cache_k_scales = getattr(layer, "cache_k_scale", None) cache_v_scales = getattr(layer, "cache_v_scale", None) - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -309,7 +312,7 @@ def forward_mixed( q, k, v, - forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_q[: forward_meta.attn_cu_seqlens_k.shape[0]], forward_meta.attn_cu_seqlens_k, forward_meta.seq_lens_encoder, res_encoder, diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 61ccc4e16e7..c2a0c9e6a2d 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -18,7 +18,7 @@ import paddle -paddle.enable_compat(scope={"flash_mla"}) # Enable torch proxy before importing flash_mla +paddle.enable_compat(scope={"flash_mla"}) # Enable paddle.enable_compat before importing flash_mla import math import os from dataclasses import dataclass, field @@ -34,6 +34,7 @@ logger.debug(f"flash_attention_v3_varlen not available: {e}") flash_attention_v3_varlen = None +from fastdeploy import envs from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, init_kv_signal_per_query, @@ -263,7 +264,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -358,7 +359,11 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): # pd_disaggregation metadata.kv_signal_data_list = [None] * self.num_layers if self.pd_disaggregation_mode == "per_chunk": - if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + if ( + not self.keep_pd_step_flag + and not forward_meta.is_dummy_or_profile_run + and not envs.FD_PD_TRANSFER_VIA_STORAGE + ): init_kv_signal_per_query( forward_meta.seq_lens_encoder, forward_meta.seq_lens_this_time, @@ -366,7 +371,7 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.rank, self.num_layers + self.num_layers_draft_model, ) - elif self.pd_disaggregation_mode == "per_query": + elif self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_metadata = open_shm_and_get_meta_signal( self.rank, int(self.device_id), self.keep_pd_step_flag ) @@ -405,7 +410,7 @@ def forward_extend( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -459,7 +464,7 @@ def forward_decode( """ metadata = self.attention_metadata - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, @@ -549,7 +554,7 @@ def forward_mixed( speculate_decoder = self.speculative_method is not None speculate_max_tokens = self.speculate_max_draft_token_num - if self.pd_disaggregation_mode == "per_query": + if self.pd_disaggregation_mode == "per_query" and not envs.FD_PD_TRANSFER_VIA_STORAGE: metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index, diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index e0175573fa3..d5d6c45afa7 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,9 @@ """ from .append_attention import append_attention, append_attention_with_output +from .config_for_attention import config_for_attention +from .decode_unified_attention import decode_unified_attention +from .decoder_write_cache_with_rope import decoder_write_cache_with_rope from .flash_attn_v4 import flash_attn_v4 from .flash_mask_attention import flash_mask_attention from .get_attn_mask_q import get_attn_mask_q @@ -37,4 +40,7 @@ "flash_attn_v4", "flash_mask_attention", "get_attn_mask_q", + "config_for_attention", + "decoder_write_cache_with_rope", + "decode_unified_attention", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py new file mode 100644 index 00000000000..d8226aad4b1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + config_for_attention as config_for_attention_cuda, + ) + + +def config_for_attention( + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, + cache_quant_type: str = "none", + group_size: int = 1, + kv_num_heads: int = 1, + max_tokens_per_batch: int = 1, +): + """ + append_attention + """ + if current_platform.is_cuda(): + config_for_attention_cuda( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_indices, + num_blocks, + chunk_size, + max_len_tensor_cpu, + cache_quant_type, + group_size, + kv_num_heads, + max_tokens_per_batch, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py new file mode 100644 index 00000000000..fedfc33dc7c --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py @@ -0,0 +1,105 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decode_unified_attention as decode_unified_attention_cuda, + ) + + +def decode_unified_attention( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + set_max_lengths: paddle.Tensor, + attn_mask: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + fmha_out: Optional[paddle.Tensor] = None, + cache_quant_type: str = "none", + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + max_tokens_per_batch: int = 1, + causal: bool = True, + sliding_window: int = 0, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + out = decode_unified_attention_cuda( + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + set_max_lengths, + attn_mask, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + mask_offset, + sinks, + fmha_out, + cache_quant_type, + max_input_length, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + causal, + sliding_window, + ) + return out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py new file mode 100644 index 00000000000..b10f6cd1bf6 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decoder_write_cache_with_rope as decoder_write_cache_with_rope_cuda, + ) + + +def decoder_write_cache_with_rope( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + set_max_lengths: paddle.Tensor, + rotary_embs: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + speculate_decoder: bool = False, +) -> paddle.Tensor: + """ + append_attention + """ + if current_platform.is_cuda(): + qkv_out = decoder_write_cache_with_rope_cuda( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + set_max_lengths, + rotary_embs, + qkv_bias, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + quant_max_bound, + quant_min_bound, + speculate_decoder, + ) + return qkv_out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py b/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py index 092912149a9..d01973f80d0 100644 --- a/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/iluvatar/attention/mha_attn_backend.py @@ -89,7 +89,7 @@ def __init__( # note: scale need to change if using MLA self.scale = 1.0 / sqrt(head_dim) self.dtype = paddle.get_default_dtype() - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.rope_batch_stride = self.max_context_len * self.head_dim if self.enable_mm else 0 if "paddleocr" in fd_config.model_config.model_type: self.is_interleaved_rope_mode = False diff --git a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py index 82938c87367..bd2d8505228 100644 --- a/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py @@ -219,7 +219,7 @@ def __init__( self.block_size = llm_config.cache_config.block_size self.max_seq_len = llm_config.model_config.max_model_len self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta - self.rope_3d = getattr(llm_config.model_config, "rope_3d", False) + self.rope_3d = llm_config.enable_rope_3d_runtime self.causal = getattr(llm_config.model_config, "causal", True) self.speculative_method = llm_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index 74fc27f67b4..0fd3553fda9 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -101,7 +101,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None @@ -128,7 +128,7 @@ def __init__( fd_config.parallel_config.expert_parallel_rank = 0 self.rank, self.device_id = init_rank_and_device_id(fd_config) - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.model_type = fd_config.model_config.model_type self.is_neox_style = False if "paddleocr" in fd_config.model_config.model_type: diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py index c905086f9f7..dcd5589f0d8 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -105,7 +105,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.use_speculate: bool = self.speculative_method is not None diff --git a/fastdeploy/model_executor/layers/backends/xpu/attention.py b/fastdeploy/model_executor/layers/backends/xpu/attention.py index 85565d33efb..31fce9bdf51 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/attention.py +++ b/fastdeploy/model_executor/layers/backends/xpu/attention.py @@ -88,9 +88,7 @@ def __init__( self.rope_theta: float = ( 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta ) - self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( - fd_config.model_config, "use_3d_rope", False - ) + self.rope_3d: bool = fd_config.enable_rope_3d_runtime self.causal: bool = getattr(fd_config.model_config, "causal", True) self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) diff --git a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py index c0df764c07c..21f2c2dedc4 100644 --- a/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py +++ b/fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py @@ -805,9 +805,9 @@ def enable_batch_invariant_mode(): if _batch_invariant_MODE: return - if hasattr(paddle, "compat") and hasattr(paddle.compat, "enable_torch_proxy"): - paddle.compat.enable_torch_proxy() - # TODO(liujundong): Enabling torch proxy here has a global effect. + if hasattr(paddle, "enable_compat"): + paddle.enable_compat() + # TODO(liujundong): Enabling paddle.enable_compat() here has a global effect. # Do NOT call this function from module import time, # otherwise it may affect other test cases during pytest collection. # (ex: Could not import module 'PretrainedTokenizer' or No module named 'paddle.distributed.tensor') diff --git a/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py new file mode 100644 index 00000000000..7f27b52975d --- /dev/null +++ b/fastdeploy/model_executor/layers/flashinfer_comm_fusion.py @@ -0,0 +1,209 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional, Tuple + +import paddle +import paddle.distributed as dist + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.utils import has_flashinfer +from fastdeploy.utils import get_logger + +logger = get_logger("flashinfer", "flashinfer.log") + +_flashinfer_comm = None +_workspace_manager = None + + +def _get_flashinfer_comm(): + """Lazily import flashinfer.comm to avoid side effects at module load time.""" + global _flashinfer_comm + if _flashinfer_comm is not None: + return _flashinfer_comm + if has_flashinfer(): + try: + with paddle.use_compat_guard(enable=True, scope={"flashinfer"}): + import flashinfer.comm as comm + + _flashinfer_comm = comm + except ImportError: + logger.warning("flashinfer.comm is not available, falling back to standard " "implementation") + return _flashinfer_comm + + +class FlashInferWorkspaceManager: + def __init__(self): + self.workspace_tensor = None + self.ipc_handles = None + self.world_size = None + self.rank = None + self.initialized = False + + def initialize( + self, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + group=None, + use_fp32_lamport: bool = False, + ): + """Initialize workspace""" + if self.initialized and self.world_size == world_size: + return + + comm = _get_flashinfer_comm() + if comm is None: + logger.warning("FlashInfer comm not available, skipping workspace " "initialization") + return + + self.cleanup() + + self.ipc_handles, self.workspace_tensor = comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + max_token_num, + hidden_dim, + group=group, + use_fp32_lamport=use_fp32_lamport, + ) + + self.world_size = world_size + self.rank = rank + self.initialized = True + + logger.info(f"FlashInfer workspace initialized for rank {rank}, " f"world_size {world_size}") + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.ipc_handles is not None: + try: + comm = _get_flashinfer_comm() + if comm is not None: + comm.trtllm_destroy_ipc_workspace_for_all_reduce(self.ipc_handles, group=dist.get_group()) + except Exception as e: + logger.warning(f"Failed to cleanup FlashInfer workspace: {e}") + finally: + self.workspace_tensor = None + self.ipc_handles = None + self.initialized = False + + +_workspace_manager = FlashInferWorkspaceManager() + + +def ensure_workspace_initialized( + fd_config: FDConfig, max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False +): + """Ensure workspace is initialized""" + comm = _get_flashinfer_comm() + if not has_flashinfer() or comm is None: + return False + + assert fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + return False + + rank = dist.get_rank() + + if not _workspace_manager.initialized or _workspace_manager.world_size != world_size: + _workspace_manager.initialize( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + use_fp32_lamport=use_fp32_lamport, + ) + + return _workspace_manager.initialized + + +def flashinfer_allreduce_residual_rmsnorm( + fd_config: FDConfig, + input_tensor: paddle.Tensor, + residual: paddle.Tensor, + weight: paddle.Tensor, + eps: float = 1e-6, + max_token_num: int = 2048, + use_oneshot: Optional[bool] = None, + trigger_completion_at_end: bool = False, + fp32_acc: bool = False, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Use FlashInfer's fused allreduce + residual + RMS norm operation + """ + comm = _get_flashinfer_comm() + if not has_flashinfer() or comm is None: + logger.debug("FlashInfer not available, falling back to standard " "implementation") + return None, None + + assert fd_config is not None + world_size = fd_config.parallel_config.tensor_parallel_size + if world_size <= 1: + logger.debug("Single GPU, no need for allreduce fusion") + return None, None + + assert input_tensor.shape[0] <= max_token_num + + if not ensure_workspace_initialized( + fd_config=fd_config, + max_token_num=max_token_num, + hidden_dim=input_tensor.shape[-1], + use_fp32_lamport=(input_tensor.dtype == paddle.float32), + ): + logger.debug("FlashInfer workspace not available") + return None, None + + token_num, hidden_dim = input_tensor.shape + + residual_out = paddle.empty_like(residual) + norm_out = paddle.empty_like(input_tensor) + # support empty tensor + if input_tensor.shape[0] == 0: + return norm_out, residual_out + comm.trtllm_allreduce_fusion( + allreduce_in=input_tensor, + world_size=world_size, + world_rank=dist.get_rank(), + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=_workspace_manager.workspace_tensor, + launch_with_pdl=True, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=(comm.AllReduceFusionPattern.kARResidualRMSNorm), + allreduce_out=None, + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + quant_out=None, + scale_out=None, + rms_gamma=weight, + rms_eps=eps, + scale_factor=None, + layout_code=None, + ) + + return norm_out, residual_out + + +def cleanup_flashinfer_workspace(): + global _workspace_manager + if _workspace_manager is not None: + _workspace_manager.cleanup() diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 2bee885ff43..b9138adad06 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -82,10 +82,17 @@ def process_loaded_weights(self, layer, weights) -> None: layer.weight.set_value(weights) def apply(self, layer: nn.Layer, x: paddle.Tensor) -> paddle.Tensor: - linear_out = paddle.matmul(x, layer.weight) if layer.with_bias: - linear_out = paddle.add(linear_out, layer.bias) - return linear_out + bias = layer.bias + assert bias.dim() == 1 and bias.shape[-1] == layer.weight.shape[-1], ( + f"bias must be 1D with size equal to the last dim of weight, " + f"but got bias.shape={bias.shape}, weight.shape[-1]={layer.weight.shape[-1]}" + ) + out = paddle.nn.functional.linear(x, layer.weight, bias) + else: + out = paddle.matmul(x, layer.weight) + + return out class LinearBase(nn.Layer): @@ -846,6 +853,7 @@ def __init__( skip_quant: bool = False, weight_dtype: str = "", layer_id: int = -1, + enable_all_reduce_fusion: bool = None, ): """ Initialize a linear layer with additional parameters for inference and quantization. @@ -857,9 +865,17 @@ def __init__( input_size (int): Number of input features. Defaults to None. output_size (int): Number of output features. Defaults to None. with_bias (bool): Whether to include bias or not. Defaults to False. - skip_quant (bool): Whether to skip quantization. Defaults to False. + skip_quant (bool): Whether to skip quantization or not. Defaults to False. + enable_all_reduce_fusion (bool, optional): Whether to enable all-reduce fusion. + If None, it is determined by the config flag and prefix. Defaults to None. """ self.fd_config = fd_config + if enable_all_reduce_fusion is None: + self.enable_all_reduce_fusion = False + else: + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and enable_all_reduce_fusion + ) self.ep_size = fd_config.parallel_config.expert_parallel_size self.tp_size = fd_config.parallel_config.tensor_parallel_size self.tp_group = fd_config.parallel_config.tp_group @@ -937,7 +953,10 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor: out = self.quant_method.apply(self, x) - if self.reduce_results and self.tp_size > 1: + need_tp_all_reduce = ( + self.reduce_results and self.tp_size > 1 and not (self.enable_all_reduce_fusion and out.shape[0] <= 2048) + ) + if need_tp_all_reduce: out = tensor_model_parallel_all_reduce(out, self.tp_group) return out diff --git a/fastdeploy/model_executor/layers/moe/__init__.py b/fastdeploy/model_executor/layers/moe/__init__.py index 540a0828ae5..7f2ded19cb6 100644 --- a/fastdeploy/model_executor/layers/moe/__init__.py +++ b/fastdeploy/model_executor/layers/moe/__init__.py @@ -17,7 +17,7 @@ CutlassW4AFP8MoEMethod, CutlassWeightOnlyMoEMethod, ) -from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod +from .fused_moe_triton_backend import TritonMoEMethod, TritonWeightOnlyMoEMethod from .moe import FusedMoE __all__ = [ @@ -26,4 +26,5 @@ CutlassW4AFP8MoEMethod, FusedMoE, TritonWeightOnlyMoEMethod, + TritonMoEMethod, ] diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 243567a422f..967c2a2fd02 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import inspect import traceback from abc import abstractmethod from types import ModuleType @@ -26,6 +27,7 @@ import fastdeploy from fastdeploy import envs from fastdeploy.config import MoEPhase +from fastdeploy.platforms import current_platform from fastdeploy.utils import singleton @@ -39,10 +41,13 @@ def load_deep_ep() -> ModuleType: try: if envs.FD_USE_PFCC_DEEP_EP: - # Enable torch proxy before importing deep_ep (required by PFCC/PaddleFleet variants) - paddle.compat.enable_torch_proxy(scope={"deep_ep"}) + # Enable paddle.enable_compat before importing deep_ep (required by PFCC/PaddleFleet variants) + paddle.enable_compat(scope={"deep_ep"}) try: - import paddlefleet.ops.deep_ep as deep_ep # type: ignore + try: + import paddlefleet.ops.deep_ep as deep_ep # type: ignore + except: + import paddlefleet_ops.deep_ep as deep_ep # type: ignore logger.info("FD use PaddleFleet/DeepEP now.") return deep_ep @@ -509,6 +514,7 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): expert_in_rank_num_list=expert_in_rank_num_list, tokens_per_expert_stats_list=tokens_per_expert_stats_list, redundant_ep_rank_num_plus_one=layer.fd_config.eplb_config.redundant_experts_num + 1, + topk_reduce_func=getattr(layer, "topk_reduce_func", None), ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_redundant_topk_select( @@ -526,6 +532,9 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): if layer.topk_method == "noaux_tc": from fastdeploy.model_executor.layers.moe.moe import get_moe_scores + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) score, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -534,6 +543,8 @@ def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor): layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) else: topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( @@ -600,6 +611,8 @@ def __init__( use_internode_ll_two_stage=use_internode_ll_two_stage, ) self.num_worst_tokens = prefill_num_worst_tokens + self._dispatch_parameters: Optional[set] = None + self._combine_parameters: Optional[set] = None logger.info(f"prefill_num_worst_tokens {prefill_num_worst_tokens}") def set_allocate_on_comm_stream(allocate_on_comm_stream: bool = False): @@ -654,8 +667,12 @@ def dispatch( } if envs.FD_USE_PFCC_DEEP_EP: - dispatch_args["num_worst_tokens"] = self.num_worst_tokens - dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + if self._dispatch_parameters is None: + self._dispatch_parameters = set(inspect.signature(buffer.dispatch).parameters) + if "num_worst_tokens" in self._dispatch_parameters: + dispatch_args["num_worst_tokens"] = self.num_worst_tokens + if "skip_x_record_stream" in self._dispatch_parameters: + dispatch_args["skip_x_record_stream"] = self.num_worst_tokens > 0 return buffer.dispatch(**dispatch_args) @@ -681,7 +698,10 @@ def combine( } if envs.FD_USE_PFCC_DEEP_EP: - combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0 + if self._combine_parameters is None: + self._combine_parameters = set(inspect.signature(buffer.combine).parameters) + if "skip_x_record_stream" in self._combine_parameters: + combine_args["skip_x_record_stream"] = self.num_worst_tokens > 0 fused_moe_out, _, event = buffer.combine(**combine_args) return fused_moe_out, event diff --git a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py index b449c246cc0..26c729cbf34 100644 --- a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py +++ b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py @@ -18,7 +18,7 @@ import paddle -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) def _dtype_str(dtype) -> str: diff --git a/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..44d7e54ae88 --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/fused_cast_sigmoid_bias.py @@ -0,0 +1,73 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import paddle + +_FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = None + +try: + from fastdeploy.model_executor.ops.gpu import ( + fused_cast_sigmoid_bias as _fused_cast_sigmoid_bias_cuda, + ) +except ImportError as e: + _fused_cast_sigmoid_bias_cuda = None + _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR = e + + +def is_available() -> bool: + """Return whether the fused GPU custom op is available.""" + return _fused_cast_sigmoid_bias_cuda is not None + + +def fused_cast_sigmoid_bias( + gate_out: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + cast_type: str = "float32", +) -> tuple: + """ + Fused operation: cast gate_out to the specified type, apply sigmoid, and add bias. + + This function fuses the following three separate operations: + 1. gate_out = gate_out.cast(cast_type) + 2. scores = sigmoid(gate_out) + 3. scores_with_bias = scores + e_score_correction_bias + + Args: + gate_out: [num_tokens, num_experts], bf16/fp16/fp32 dtype - raw gate output + e_score_correction_bias: [num_experts], fp32 dtype - correction bias + cast_type: output dtype string, supports "float32", "float16", "bfloat16" + + Returns: + scores: [num_tokens, num_experts], cast_type dtype - result of sigmoid(gate_out) + scores_with_bias: [num_tokens, num_experts], cast_type dtype - scores with bias added + + Precision: + All intermediate computations (cast, sigmoid, bias addition) are performed + in float32 precision; conversion to cast_type happens only at the final store. + When cast_type is "float32", the result is bit-exact with the following + reference implementation: + gate_fp32 = gate_out.cast("float32") + scores = sigmoid(gate_fp32) + scores_with_bias = scores + bias + When cast_type is "float16"/"bfloat16", the only precision loss comes from + the final type conversion, equivalent to calling .cast(cast_type) after + computing in float32. + """ + if _fused_cast_sigmoid_bias_cuda is None: + raise ImportError( + "fused_cast_sigmoid_bias is not available. " "Please ensure the GPU custom ops are compiled." + ) from _FUSED_CAST_SIGMOID_BIAS_IMPORT_ERROR + return _fused_cast_sigmoid_bias_cuda(gate_out, e_score_correction_bias, cast_type) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 0c86270c630..dbd0679ebcd 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -28,7 +28,11 @@ from .fused_moe_backend_base import UnquantizedFusedMoEMethod if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce + from fastdeploy.model_executor.ops.gpu import ( + count_tokens_per_expert_func, + moe_expert_dispatch, + moe_expert_reduce, + ) try: from fastdeploy.model_executor.ops.gpu import ( @@ -39,6 +43,7 @@ logger.warning("import w4afp8_gemm_scale_permute Failed!") from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.quantization.fp8_utils import paddlefleet_ops from fastdeploy.model_executor.utils import ( TensorTracker, free_tensor, @@ -48,6 +53,12 @@ ) +def m_grouped_bf16_gemm_nn_contiguous(x, y, expert_idx_per_token): + out = paddle.empty([x.shape[0], y.shape[-1]], dtype=x.dtype) + paddlefleet_ops.deep_gemm.m_grouped_bf16_gemm_nn_contiguous(x, y, out, expert_idx_per_token) + return out + + class CutlassMoEMethod(UnquantizedFusedMoEMethod): """ Use Cutlass Group Gemm to compute Fused MoE. @@ -126,6 +137,7 @@ def apply_ep_prefill( # 1. Select topk experts and weights topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) # 2. EP Dispatch + dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {} ( recv_x, recv_topk_idx, @@ -133,7 +145,7 @@ def apply_ep_prefill( recv_num_tokens_per_expert_list, handle, event, - ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights) + ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs) if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) @@ -146,54 +158,107 @@ def apply_ep_prefill( # 3. Compute ffn if token_all_num > 0: logger.debug(f"token_all_num {token_all_num}") - ( - permute_input, - permute_indices_per_token, - recv_num_tokens_per_expert_list_cumsum, - dst_weights, - dst_indices, - cumsum_idx_gpu, - expert_idx_per_token, - dequant_scale, - ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch( - recv_x, - recv_topk_idx, - recv_topk_weights, - (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), - recv_num_tokens_per_expert_list, - token_all_num, - self.moe_quant_type, - ) - if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": - # only w4a8 and w4afp8 need expert_idx_per_token - # Other need not this tensor, so we make it None. - expert_idx_per_token = None + + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": + # --- moe_permute / moe_unpermute path --- + recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32) + (permute_input, permute_indices_per_token, dst_weights, _scale_out, m_indices) = ( + paddle.nn.functional.moe_permute( + hidden_states=recv_x, + scale=None, + expert_routemap_topk=recv_topk_idx_i32, + expert_prob_topk=recv_topk_weights, + num_experts=layer.num_local_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=token_all_num, + return_expert_indices=True, + ) + ) + + if fastdeploy.envs.FD_USE_DEEP_GEMM: + out = m_grouped_bf16_gemm_nn_contiguous( + permute_input, getattr(layer, self.added_weight_attrs[0]), m_indices + ) + else: + out = paddle.incubate.nn.functional.batched_gemm( + permute_input, + getattr(layer, self.added_weight_attrs[0]), + recv_num_tokens_per_expert_list, + ) + + if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE: + out = paddlefleet_ops.fused_swiglu_scale(out, dst_weights) + else: + out = paddle.incubate.nn.functional.swiglu(out) + + if fastdeploy.envs.FD_USE_DEEP_GEMM: + ffn_out = m_grouped_bf16_gemm_nn_contiguous( + out, getattr(layer, self.added_weight_attrs[1]), m_indices + ) + else: + ffn_out = paddle.incubate.nn.functional.batched_gemm( + out, + getattr(layer, self.added_weight_attrs[1]), + recv_num_tokens_per_expert_list, + ) + + tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=recv_topk_idx_i32, + token_prob_unzipped=dst_weights, + total_zipped_tokens=recv_x.shape[0], + num_experts=layer.num_local_experts, + using_weighted_combine=not fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE, + ) else: - expert_idx_per_token = expert_idx_per_token.cast("int64") + # --- original ep_moe_expert_dispatch / combine path --- + ( + permute_input, + permute_indices_per_token, + recv_num_tokens_per_expert_list_cumsum, + dst_weights, + dst_indices, + cumsum_idx_gpu, + expert_idx_per_token, + dequant_scale, + ) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch( + recv_x, + recv_topk_idx, + recv_topk_weights, + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), + recv_num_tokens_per_expert_list, + token_all_num, + self.moe_quant_type, + ) + if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": + expert_idx_per_token = None + else: + expert_idx_per_token = expert_idx_per_token.cast("int64") - if hasattr(layer, "up_gate_proj_in_scale"): - dequant_scale = None + if hasattr(layer, "up_gate_proj_in_scale"): + dequant_scale = None - ffn_out = self.compute_ffn( - layer, - permute_input, - recv_num_tokens_per_expert_list_cumsum, - expert_idx_per_token, - False, - -1, - dequant_scale, - ) + ffn_out = self.compute_ffn( + layer, + permute_input, + recv_num_tokens_per_expert_list_cumsum, + expert_idx_per_token, + False, + -1, + dequant_scale, + ) - # prmt back per rank - tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( - ffn_out, - dst_weights, - permute_indices_per_token, - dst_indices, - None, # down_proj_bias, - False, # norm_topk_prob - 1.0, - ) + tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine( + ffn_out, + dst_weights, + permute_indices_per_token, + dst_indices, + None, # down_proj_bias, + False, # norm_topk_prob + 1.0, + ) else: tmp_ffn_out = recv_x @@ -275,8 +340,82 @@ def apply_tp( Paddle Cutlass compute Fused MoE. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") + if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16": + if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") + gate_out, topk_weights, topk_idx = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + use_fused_cast=use_fused, + ) + else: + gate_out = gate_out.cast("float32") + topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + topk_idx_i32 = topk_idx.astype(paddle.int32) + override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) + (permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap + paddle.nn.functional.moe_permute( + hidden_states=x, + scale=None, + expert_routemap_topk=topk_idx_i32, + expert_prob_topk=topk_weights, + num_experts=layer.num_experts, + tokens_per_expert=[], + padding_alignment=128, + override_buffer_size=override_buffer_size, + ) + ) + + # Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert. + token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast( + paddle.int64 + ) + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_idx) + + ffn_out = self.compute_ffn( + layer, + permute_input, + token_nums_per_expert_cumsum, + None, # expert_idx_per_token not needed for w16a16 without bias + False, + -1, + None, # dequant_scale + None, # max_tokens_per_expert + ) + + fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute( + hidden_states_unzipped=ffn_out, + zipped_expertwise_rowmap=permute_indices_per_token, + expert_routemap_topk=topk_idx_i32, + token_prob_unzipped=dst_weights, + total_zipped_tokens=x.shape[0], + num_experts=layer.num_experts, + using_weighted_combine=True, + ) + return fused_moe_out + if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_idx = get_moe_scores( gate_out, layer.n_group, @@ -285,7 +424,10 @@ def apply_tp( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, ) + ( permute_input, token_nums_per_expert, @@ -308,6 +450,7 @@ def apply_tp( topk_only_mode=True, ) else: + gate_out = gate_out.cast("float32") ( permute_input, token_nums_per_expert, @@ -340,7 +483,6 @@ def apply_tp( expert_idx_per_token = None else: expert_idx_per_token = expert_idx_per_token.cast("int64") - ffn_out = self.compute_ffn( layer, permute_input, @@ -362,7 +504,6 @@ def apply_tp( norm_topk_prob=False if layer.topk_method == "noaux_tc" else True, routed_scaling_factor=1.0, ) - return fused_moe_out diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 135fb5ecafc..adeac2f2cf4 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -188,7 +188,7 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=not disable_ue8m0_cast, + using_pow2_scale=not disable_ue8m0_cast or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=not disable_ue8m0_cast, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -207,67 +207,6 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op( return ffn_out -def moe_topk_select( - gating_output: paddle.Tensor, - n_group: int, - topk_group: int, - top_k: int, - routed_scaling_factor: float, - e_score_correction_bias: paddle.Tensor, - renormalize: bool = False, -): - """ - Topk selection using paddle PHI topk API. - - Args: - gating_output: gate output logits, shape [seq_len, n_experts] - n_group: number of expert groups - topk_group: number of top-k groups to select - top_k: number of top experts per token - routed_scaling_factor: scaling factor for routed experts - e_score_correction_bias: bias for expert selection - renormalize: whether to renormalize topk probabilities - - Returns: - topk_weights: normalized topk probabilities, shape [seq_len, top_k] - topk_ids: topk expert indices, shape [seq_len, top_k] - """ - # compute gate probs via sigmoid - gate_probs = paddle.nn.functional.sigmoid(gating_output) - # probs_for_choice includes correction bias for topk selection - probs_for_choice = gate_probs + e_score_correction_bias if e_score_correction_bias is not None else gate_probs - # group-based topk selection - n_group = n_group if n_group > 0 else 1 - topk_group = topk_group if topk_group > 0 else 1 - if n_group > 1 and topk_group < n_group: - seq_length, n_experts = probs_for_choice.shape - group_scores = ( - probs_for_choice.reshape([seq_length, n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) - ) # [seq_len, n_group] - group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [seq_len, topk_group] - group_mask = paddle.sum( - paddle.nn.functional.one_hot(group_idx, num_classes=n_group).cast(group_scores.dtype), - axis=1, # Sum over topk_group dimension -> [seq_len, n_group] - ) - score_mask = ( - group_mask.unsqueeze(-1).expand([seq_length, n_group, n_experts // n_group]).reshape([seq_length, -1]) - ) # [seq_len, n_experts] - probs_for_choice = probs_for_choice.masked_fill(~score_mask.astype(paddle.bool), float("-inf")) - - _, topk_ids = paddle.topk(probs_for_choice, top_k, axis=-1) - topk_weights = paddle.index_sample(gate_probs, topk_ids) - - # normalize combine weights - if renormalize: - topk_weights = topk_weights / paddle.clip(topk_weights.sum(-1, keepdim=True), min=1e-12) - - # apply routed scaling factor - if routed_scaling_factor: - topk_weights = topk_weights * routed_scaling_factor - - return topk_weights, topk_ids - - class DeepGemmFusedMoeMethod(MoEMethodBase): """ DeepGemmFusedMoeMethod is a class that implements the MoEMethodBase interface for DeepGemm backend. @@ -403,22 +342,7 @@ def apply_ep_prefill( hidden_size = x.shape[1] # 1. Select topk experts and weights - if ( - fastdeploy.envs.FD_USE_PHI_MOE_TOPK - and layer.redundant_table_manger is None - and layer.topk_method == "noaux_tc" - ): - topk_weights, topk_idx = moe_topk_select( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - else: - topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) + topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out) if topk_ids_hookfunc is not None: topk_ids_hookfunc(topk_ids=topk_idx) @@ -431,7 +355,7 @@ def apply_ep_prefill( else: x_fp8, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) @@ -597,7 +521,7 @@ def apply_ep_prefill( using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) else: - token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts) + token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts, False) ( permute_input, permute_scale, @@ -657,7 +581,8 @@ def apply_ep_prefill( else: ffn_in_x, ffn_in_x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( ffn_out, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 + or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) ffn_in_x_scale_tensor = ffn_in_x_scale_tensor.T[: ffn_in_x.shape[0]] @@ -817,32 +742,26 @@ def apply_tp( below is TP compute method. """ gate_out = gate(x) - gate_out = gate_out.cast("float32") if layer.topk_method == "noaux_tc": - - if not fastdeploy.envs.FD_USE_PHI_MOE_TOPK: - _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - else: - topk_weights, topk_ids = moe_topk_select( - gate_out, - layer.n_group, - layer.topk_group, - layer.top_k, - layer.routed_scaling_factor, - layer.gate_correction_bias, - getattr(layer, "renormalize", True), - ) - + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") + _, topk_weights, topk_ids = fastdeploy.model_executor.layers.moe.moe.get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + use_fused_cast=use_fused, + ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -861,7 +780,7 @@ def apply_tp( else: recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=self.quant_config.deepgemm_scale_ue8m0, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) @@ -893,7 +812,7 @@ def apply_tp( using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) else: - tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts) + tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts, False) ( permute_input, permute_scale, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index d1db43a3241..be215065db3 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -20,6 +20,17 @@ from paddle import nn import fastdeploy +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores +from fastdeploy.model_executor.layers.moe.triton_moe_kernels import ( + fused_moe_kernel_bf16, + fused_moe_kernel_paddle, +) +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + fused_stack_transpose_quant, + quant_weight_ue8m0, + transform_scale_ue8m0, +) +from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import ( TensorTracker, @@ -28,23 +39,18 @@ set_weight_attrs, weight_fully_copied, ) +from fastdeploy.platforms import current_platform from fastdeploy.utils import ceil_div, register_custom_python_op -from ..quantization.quant_base import QuantMethodBase - try: - from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func + import triton.language as tl - from .triton_moe_kernels import fused_moe_kernel_paddle + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func except ImportError: pass -from fastdeploy.model_executor.layers.moe.moe import get_moe_scores -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - fused_stack_transpose_quant, - quant_weight_ue8m0, - transform_scale_ue8m0, -) -from fastdeploy.model_executor.layers.quantization.ops import scaled_fp8_quant + +from ..quantization.quant_base import QuantMethodBase +from .fused_moe_backend_base import UnquantizedFusedMoEMethod class TritonWeightOnlyMoEMethod(QuantMethodBase): @@ -299,7 +305,7 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") + top_k = layer.top_k num_local_experts = layer.num_local_experts top_k = layer.top_k @@ -307,6 +313,11 @@ def apply( hidden_size = layer.hidden_size if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -315,8 +326,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -688,7 +701,6 @@ def apply( if token_num == 0: return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) gate_out = gate(x) - gate_out = gate_out.cast("float32") top_k = layer.top_k num_local_experts = layer.num_local_experts moe_intermediate_size = layer.moe_intermediate_size @@ -696,6 +708,11 @@ def apply( E, N1, _ = getattr(layer, self.added_weight_attrs[0]).shape if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") gate_out, topk_weights, topk_ids = get_moe_scores( gate_out, layer.n_group, @@ -704,8 +721,10 @@ def apply( layer.routed_scaling_factor, layer.gate_correction_bias, getattr(layer, "renormalize", True), + use_fused_cast=use_fused, ) else: + gate_out = gate_out.cast("float32") topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( gate_out, layer.gate_correction_bias, @@ -1247,7 +1266,7 @@ def python_op_fused_moe_kernel_paddle( x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, quant_config.weight_block_size[0], False) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - x, using_pow2_scale=False, output_scale_transpose=False + x, using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=False ) x_scale = x_scale[: x.shape[0]] @@ -1305,7 +1324,9 @@ def python_op_fused_moe_kernel_paddle( ) else: x_q, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( - intermediate_cache2, using_pow2_scale=False, output_scale_transpose=False + intermediate_cache2, + using_pow2_scale=fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, + output_scale_transpose=False, ) x_scale = x_scale[: x_q.shape[0]] @@ -1836,3 +1857,241 @@ def apply( self.quant_config, topk_ids_hookfunc, ) + + +class TritonMoEMethod(UnquantizedFusedMoEMethod): + """ + Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE. + + Activated via: export FD_MOE_BACKEND=triton + Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj. + This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA. + """ + + def __init__(self, quant_config=None): + super().__init__(quant_config) + + def process_loaded_weights(self, layer: nn.Layer, state_dict): + """Stack individual expert weights into the stacked parameter.""" + up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict) + layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0)) + layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0)) + + def _get_default_config(self, M: int, E: int) -> dict: + """ + Heuristic tile config for BF16 MoE, ported verbatim from vLLM's + `get_default_config` (bf16/fp16 non-block_shape branch). + See vllm/model_executor/layers/fused_moe/fused_moe.py:1273-1319. + + M: number of tokens (A.size(0) in vLLM), i.e. pre-expansion token count. + E: number of (local) experts. + """ + + # Tile sizes scale with batch: small batches are memory-bound + # (favor tall-K tiles), large batches are compute-bound (favor + # large M/N tiles with more warps). + if M <= 32: + block_m = 16 + elif M <= 96: + block_m = 32 + elif M <= 512: + block_m = 64 + else: + block_m = 128 + + block_n = 64 if M <= 64 else 128 + + block_k = 64 + + # Grouping adjacent M-blocks lets them share weight tiles in L2. + # Only helps when there are enough M-blocks per expert to group; + # with many experts each one sees few tokens so grouping is useless. + tokens_per_expert = M // max(E, 1) + group_m = 16 if tokens_per_expert > 128 else 1 + + # Large batches have enough blocks to saturate the GPU, so we + # use more warps per block to increase arithmetic intensity. + num_warps = 4 if M <= 128 else 8 + + num_stages = 4 if M <= 32 else 3 + + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + } + + def apply_tp( + self, + layer: nn.Layer, + x: paddle.Tensor, + gate: nn.Layer, + topk_ids_hookfunc: Callable = None, + fc1_latent_proj: nn.Layer = None, + fc2_latent_proj: nn.Layer = None, + ) -> paddle.Tensor: + """ + BF16 Triton Fused MoE forward. + + Pipeline: + 1. Gate + topk routing + 2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded + 3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N] + 4. SwiGLU activation + 5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K] + (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication) + 6. Reshape + sum over topk dim + """ + token_num = x.shape[0] + if token_num == 0: + return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype) + + top_k = layer.top_k + num_local_experts = layer.num_local_experts + moe_intermediate_size = layer.moe_intermediate_size + hidden_size = layer.hidden_size + + # --- 1. Routing --- + gate_out = gate(x) + + if layer.topk_method == "noaux_tc": + use_fused = ( + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse and current_platform.is_cuda() + ) + if not use_fused: + gate_out = gate_out.cast("float32") + _, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + use_fused_cast=use_fused, + topk_reduce_func=getattr(layer, "topk_reduce_func", None), + ) + else: + gate_out = gate_out.cast("float32") + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + top_k, + True, # apply_norm_weight + False, + ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids=topk_ids) + + # --- 2. Preprocess: sort tokens by expert assignment --- + num_token_expert_pairs = token_num * top_k + # vLLM convention: pass num_tokens (pre-expansion), NOT tokens*top_k. + cfg = self._get_default_config(token_num, num_local_experts) + + sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func( + topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"] + ) + max_possible_num_post_padded = sorted_token_ids.shape[0] + + # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) --- + # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn + up_gate_proj_out = paddle.empty( + [num_token_expert_pairs, moe_intermediate_size * 2], + dtype=x.dtype, + ) + grid1 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) + * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid1]( + x, + layer.up_gate_proj_weight, + up_gate_proj_out, + None, # topk_weights_ptr (no weight mul on GEMM1) + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=moe_intermediate_size * 2, + K=hidden_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=x.strides[0], + stride_ak=x.strides[1], + stride_be=layer.up_gate_proj_weight.strides[0], + stride_bk=layer.up_gate_proj_weight.strides[1], + stride_bn=layer.up_gate_proj_weight.strides[2], + stride_cm=up_gate_proj_out.strides[0], + stride_cn=up_gate_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=False, + top_k=top_k, + compute_type=tl.bfloat16, + even_Ks=(hidden_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 4. SwiGLU activation --- + down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out) + + # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication --- + # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn + down_proj_out = paddle.empty( + (num_token_expert_pairs, hidden_size), + dtype=x.dtype, + ) + grid2 = ( + ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"]) * ceil_div(hidden_size, cfg["BLOCK_SIZE_N"]), + ) + fused_moe_kernel_bf16[grid2]( + down_proj_input, + layer.down_proj_weight, + down_proj_out, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N=hidden_size, + K=moe_intermediate_size, + EM=max_possible_num_post_padded, + num_valid_tokens=num_token_expert_pairs, + stride_am=down_proj_input.strides[0], + stride_ak=down_proj_input.strides[1], + stride_be=layer.down_proj_weight.strides[0], + stride_bk=layer.down_proj_weight.strides[1], + stride_bn=layer.down_proj_weight.strides[2], + stride_cm=down_proj_out.strides[0], + stride_cn=down_proj_out.strides[1], + BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], + BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"], + BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"], + GROUP_SIZE_M=cfg["GROUP_SIZE_M"], + MUL_ROUTED_WEIGHT=True, + top_k=1, + compute_type=tl.bfloat16, + even_Ks=(moe_intermediate_size % cfg["BLOCK_SIZE_K"] == 0), + num_warps=cfg["num_warps"], + num_stages=cfg["num_stages"], + ) + + # --- 6. Reduce over topk --- + down_proj_out.reshape_([token_num, top_k, hidden_size]) + out = down_proj_out.sum(axis=1) + return out + + def apply_ep_prefill( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + raise NotImplementedError("TritonMoEMethod does not support EP prefill yet.") + + def apply_ep_decode( + self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None, fc1_latent_proj=None, fc2_latent_proj=None + ): + raise NotImplementedError("TritonMoEMethod does not support EP decode yet.") diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py index 4e56c7485f9..b2354327631 100644 --- a/fastdeploy/model_executor/layers/moe/moe.py +++ b/fastdeploy/model_executor/layers/moe/moe.py @@ -28,7 +28,7 @@ ) from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - save_routing_to_buffer, + save_routing_to_buffer_v2, ) from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.utils import h2d_copy, slice_fn @@ -36,7 +36,11 @@ from fastdeploy.worker.experts_manager import RedundantExpertManger try: - from fastdeploy.model_executor.ops.gpu import noaux_tc, noaux_tc_redundant + from fastdeploy.model_executor.ops.gpu import ( + grouped_topk, + noaux_tc, + noaux_tc_redundant, + ) except: logger.warning("import noaux_tc Failed!") import numpy as np @@ -48,6 +52,11 @@ def get_moe_method(layer=None): """ if current_platform.is_cuda(): + moe_backend = envs.FD_MOE_BACKEND.lower() + if moe_backend == "triton": + from .fused_moe_triton_backend import TritonMoEMethod + + return TritonMoEMethod(None) from .fused_moe_cutlass_backend import CutlassMoEMethod return CutlassMoEMethod(None) @@ -90,14 +99,23 @@ def get_moe_scores( expert_in_rank_num_list: paddle.Tensor = None, tokens_per_expert_stats_list: paddle.Tensor = None, redundant_ep_rank_num_plus_one: int = 1, + topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + use_fused_cast: bool = False, ) -> paddle.Tensor: """ compute moe scores using e_score_correction_bias. """ - scores = paddle.nn.functional.sigmoid(gating_output) assert e_score_correction_bias is not None, "e_score_correction_bias is none!" - scores_with_bias = scores + e_score_correction_bias - if expert_id_to_ep_rank_array is None: + if envs.FD_USE_PHI_MOE_TOPK: + # calculate renormalize and routed_scaling_factor value outside the noaux_tc + original_renormalize = renormalize + original_routed_scaling_factor = routed_scaling_factor + renormalize = False + routed_scaling_factor = 1.0 + + if expert_id_to_ep_rank_array is None and not use_fused_cast: + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx = noaux_tc( scores, scores_with_bias, @@ -107,9 +125,20 @@ def get_moe_scores( renormalize, routed_scaling_factor, ) + elif expert_id_to_ep_rank_array is None and use_fused_cast: + # fused kernel: cast + sigmoid + add + noaux_tc + scores, topk_values, topk_idx = grouped_topk( + gating_output, + e_score_correction_bias, + n_group if n_group > 0 else 1, + topk_group if topk_group > 0 else 1, + top_k, + renormalize, + routed_scaling_factor, + ) else: - # noaux_tc_redundant returns 4 values: scores, topk_values, topk_idx, - # and tokens_per_expert_stats_list_out (inplace updated) + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias scores, topk_values, topk_idx, _ = noaux_tc_redundant( scores, scores_with_bias, @@ -123,6 +152,16 @@ def get_moe_scores( routed_scaling_factor, redundant_ep_rank_num_plus_one, ) + if envs.FD_USE_PHI_MOE_TOPK: + if original_renormalize: + if topk_reduce_func is not None: + topk_values = topk_values / topk_reduce_func(topk_values) + else: + # 使用默认的 sum + epsilon + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + + if original_routed_scaling_factor != 1.0: + topk_values *= original_routed_scaling_factor return scores, topk_values, topk_idx @@ -152,6 +191,8 @@ def __init__( with_bias: bool = False, activation="swiglu", model_format: Optional[str] = None, + topk_reduce_func: Callable = lambda x: x.sum(axis=-1, keepdim=True) + + 1e-20, # only used when FD_USE_PHI_MOE_TOPK=1, default is same as noaux_tc kernel ): """ Initialize the Moe layer with given parameters. @@ -197,6 +238,7 @@ def __init__( self.moe_tag = moe_tag self.with_bias = with_bias self.activation = activation + self.topk_reduce_func = topk_reduce_func if self.ep_size > 1: expert_id_offset = expert_id_offset + self.ep_rank * self.num_local_experts @@ -698,23 +740,23 @@ def forward( Tensor: Output tensor.s """ + topk_ids_hookfunc = None if self.enable_routing_replay: # When execute empty_input_forward forward_meta is None. When execute mtp layer routing_replay_table is None. - if forward_meta is not None and forward_meta.routing_replay_table is not None: + if forward_meta is not None and forward_meta.device_routing_buffer is not None: moe_layer_idx = self.layer_idx - self.fd_config.model_config.moe_layer_start_index topk_ids_hookfunc = partial( - save_routing_to_buffer, - routing_replay_table=forward_meta.routing_replay_table, - batch_id_per_token=forward_meta.batch_id_per_token, - seq_lens_decoder=forward_meta.seq_lens_decoder, - cu_seqlens_q=forward_meta.cu_seqlens_q, + save_routing_to_buffer_v2, + device_routing_buffer=forward_meta.device_routing_buffer, layer_idx=moe_layer_idx, tp_size=self.fd_config.parallel_config.tensor_parallel_size, ep_size=self.fd_config.parallel_config.expert_parallel_size, tp_group=self.fd_config.parallel_config.tp_group, + total_token_num=forward_meta.batch_id_per_token.shape[0], + position_ids=forward_meta.position_ids, + debug_mode=self.fd_config.routing_replay_config.debug_mode, ) - if current_platform.is_intel_hpu(): out = self.forward_normal( x, gate, forward_meta, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index b27957bf0c6..a4edbbfb724 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -14,20 +14,6 @@ # limitations under the License. """ -import asyncio -import atexit -import functools -import multiprocessing -import os -import shutil -import threading -import time -import traceback -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from multiprocessing import Process, Queue -from typing import Dict, Optional, TypedDict - import numpy as np import paddle import paddle.distributed as dist @@ -35,7 +21,8 @@ import triton.language as tl from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig, RoutingReplayConfig +from fastdeploy.cache_manager.routing_cache_manager import RoutingHostBufferView +from fastdeploy.config import FDConfig from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( enable_compat_on_triton_kernel, ) @@ -43,828 +30,313 @@ @enable_compat_on_triton_kernel @triton.jit -def _save_routing_kernel( - ROUTING_REPLAY_TABLE_PTR, +def _save_routing_kernel_v2( + device_routing_buffer_PTR, TOPK_IDS_PTR, - BATCH_ID_PER_TOKEN_PTR, - CU_SEQLENS_Q_PTR, - SEQ_LENS_DECODER_PTR, LAYER_IDX, TOKEN_NUM, TOP_K, - NUM_HIDDEN_LAYERS, - MAX_MODEL_LEN, + NUM_MOE_LAYERS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_m = tl.program_id(axis=0) - token_offsets = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) token_mask = token_offsets < TOKEN_NUM - k_offsets = tl.arange(0, BLOCK_SIZE_K) - k_mask = k_offsets < TOP_K - topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] - # [BLOCK_SIZE_M, BLOCK_SIZE_K] - load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) - - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) - pad_mask = token_mask & (batch_ids != -1) - # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] - # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) - token_relative_index = token_offsets - start_offsets - - # [BLOCK_SIZE_M] - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) - token_seq_pos = len_decoder + token_relative_index - - STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K - STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K - STRIDE_BUF_LAYER = TOP_K - - # [BLOCK_SIZE_M, BLOCK_SIZE_K] + topk_vals = tl.load( + TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :], + mask=load_mask, + ) + + STRIDE_TOKEN = NUM_MOE_LAYERS * TOP_K + STRIDE_LAYER = TOP_K output_ptrs = ( - ROUTING_REPLAY_TABLE_PTR - + batch_ids[:, None] * STRIDE_BUF_SEQ - + token_seq_pos[:, None] * STRIDE_BUF_TOKEN - + LAYER_IDX * STRIDE_BUF_LAYER + device_routing_buffer_PTR + + token_offsets[:, None] * STRIDE_TOKEN + + LAYER_IDX * STRIDE_LAYER + k_offsets[None, :] ) + tl.store(output_ptrs, topk_vals, mask=load_mask) - pos_mask = token_seq_pos < MAX_MODEL_LEN - pos_mask = pos_mask & pad_mask - - # [BLOCK_SIZE_M, BLOCK_SIZE_K] - pos_mask = pos_mask[:, None] & k_mask[None, :] - final_mask = load_mask & pos_mask - - tl.store(output_ptrs, topk_vals, mask=final_mask) - - -def save_routing_to_buffer( - routing_replay_table: paddle.Tensor, # [max_num_seqs, num_layers, max_len, top_k] - topk_ids: paddle.Tensor, # [token_num, top_k] - batch_id_per_token: paddle.Tensor, # [token_num, 1] - seq_lens_decoder: paddle.Tensor, # [max_num_seqs, 1] - cu_seqlens_q: paddle.Tensor, # [max_num_seqs + 1, 1] +def save_routing_to_buffer_v2( + device_routing_buffer: paddle.Tensor, + topk_ids: paddle.Tensor, layer_idx: int, tp_size: int, ep_size: int, tp_group: dist.communication.group.Group, + total_token_num: int = -1, + position_ids: paddle.Tensor = None, + debug_mode: bool = False, ): + token_num_per_rank = topk_ids.shape[0] + if token_num_per_rank == 0: + return if tp_size > 1 and ep_size > 1: - token_num_per_rank = topk_ids.shape[0] - if token_num_per_rank == 0: - return topk_ids_all = paddle.zeros([token_num_per_rank * tp_size, topk_ids.shape[1]], dtype=topk_ids.dtype) paddle.distributed.all_gather(topk_ids_all, topk_ids, tp_group) - topk_ids = topk_ids_all[: batch_id_per_token.shape[0], :] + assert ( + total_token_num >= token_num_per_rank + ), f"[R3] total_token_num={total_token_num} < token_num_per_rank={token_num_per_rank}" + topk_ids = topk_ids_all[:total_token_num, :] + + if debug_mode and position_ids is not None: + token_num, top_k = topk_ids.shape + hack_ids = position_ids[:token_num].cast(topk_ids.dtype) + hack_ids = hack_ids.unsqueeze(1).expand([-1, top_k]) + topk_ids = hack_ids token_num, top_k = topk_ids.shape - max_num_seqs, max_model_len, num_hidden_layers, _ = routing_replay_table.shape - assert token_num > 0 - assert topk_ids.shape[1] == routing_replay_table.shape[3], (topk_ids.shape[1], routing_replay_table.shape[3]) - assert batch_id_per_token.shape[0] == token_num, (batch_id_per_token.shape[0], token_num) - assert seq_lens_decoder.shape[0] == max_num_seqs, (seq_lens_decoder.shape[0], max_num_seqs) + buf_max_tokens, num_moe_layers, buf_top_k = device_routing_buffer.shape - BLOCK_SIZE_M = 128 - BLOCK_SIZE_K = triton.next_power_of_2(top_k) # top_k + assert ( + token_num <= buf_max_tokens + ), f"[R3] token_num={token_num} exceeds device_routing_buffer capacity={buf_max_tokens}" + assert ( + top_k == buf_top_k + ), f"[R3] top_k mismatch: topk_ids.top_k={top_k} vs device_routing_buffer.top_k={buf_top_k}" + assert 0 <= layer_idx < num_moe_layers, f"[R3] layer_idx={layer_idx} out of range [0, {num_moe_layers})" + BLOCK_SIZE_M = 128 + BLOCK_SIZE_K = triton.next_power_of_2(top_k) grid = (triton.cdiv(token_num, BLOCK_SIZE_M),) - _save_routing_kernel[grid]( - routing_replay_table, + _save_routing_kernel_v2[grid]( + device_routing_buffer, topk_ids, - batch_id_per_token, - cu_seqlens_q, - seq_lens_decoder, LAYER_IDX=layer_idx, TOKEN_NUM=token_num, TOP_K=top_k, - NUM_HIDDEN_LAYERS=num_hidden_layers, - MAX_MODEL_LEN=max_model_len, + NUM_MOE_LAYERS=num_moe_layers, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) -class RoutingReplayManager: - """Request level routing replay table manager""" +class RoutedExpertsCapturer: + """ + Worker-side routing capture: manages GPU transient buffer and GPU→CPU scatter. + Does NOT manage request lifecycle — that is handled by RoutingCacheManager on the Engine side. + """ - def __init__(self, fd_config: FDConfig, block_table, total_block_num): + def __init__(self, fd_config: FDConfig, total_block_num: int): self.fd_config = fd_config - self.block_table = block_table self.max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.max_model_len = fd_config.model_config.max_model_len - self.num_moe_layers = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - self.only_last_turn = fd_config.routing_replay_config.only_last_turn - self.use_fused_put = fd_config.routing_replay_config.use_fused_put - if fd_config.model_config.architectures[0] == "Glm4MoeForCausalLM": - self.moe_top_k = fd_config.model_config.num_experts_per_tok - else: - self.moe_top_k = fd_config.model_config.moe_k + + # Read routing params from centralized config + rrc = fd_config.routing_replay_config + self.num_moe_layers = rrc.num_moe_layers + self.moe_top_k = rrc.moe_top_k + self.routing_dtype = rrc.routing_dtype + self.debug_mode = rrc.debug_mode self.tp_rank = fd_config.parallel_config.tensor_parallel_rank + self.token_num_overlap = 0 - # Initialize the routing replay table and routing cache - self.routing_batch_to_request: Dict[int, str] = {} - num_experts = fd_config.model_config.moe_num_experts + fd_config.model_config.moe_num_shared_experts - self.routing_dtype = self.get_routing_dtype(num_experts=num_experts) - self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num) - self.pending_update_positions = None + logger.info(f"[R3] RoutedExpertsCapturer config: {rrc}") - # Initialize routing store wrapper - if self.tp_rank == 0: - self._store_wrapper = StoreWrapper( - fd_config=fd_config, - ) - self._store_wrapper.start_store_warpper() + self._init_routing_cache(dtype=self.routing_dtype, total_block_num=total_block_num) def _init_routing_cache(self, dtype: str, total_block_num: int): - """Initialize the device buffer and host buffer.""" - + """Initialize GPU transient buffer, staging buffers, and CPU pinned buffers.""" max_num_kv_tokens = total_block_num * self.fd_config.cache_config.block_size + self.max_num_kv_tokens = max_num_kv_tokens # Save for slot range validation - self._host_cache = paddle.full( - shape=[max_num_kv_tokens, self.num_moe_layers, self.moe_top_k], fill_value=-1, dtype=dtype, device="cpu" - ) - - self.routing_replay_table = paddle.full( - shape=[self.max_num_seqs, self.max_model_len, self.num_moe_layers, self.moe_top_k], - fill_value=-1, - dtype=dtype, - ) - logger.info( - f"[R3] The host cache size is:{self._host_cache.shape}, device cache size is: {self.routing_replay_table.shape}" - ) - - def get_routing_dtype(self, num_experts: int, reserved_fill_value: int = 1) -> str: - """Calculate the minimum number of bits required for storage routing.""" - if num_experts <= 0: - raise ValueError(f"num_experts must be greater than 0 but got {num_experts}, please check model config.") - dtype = "uint8" - total_number = num_experts + reserved_fill_value - if total_number <= 255: # uint8: 0~255 - dtype = "uint8" - elif total_number <= 65535: # uint16: 0~65,535 - dtype = "uint16" - elif total_number <= 4294967295: # uint32: 0~4,294,967,295 - dtype = "uint32" - else: - raise ValueError( - f"The number of experts {num_experts} exceeds the representation range of uint32, please check model config." - ) - logger.info(f"[R3] Routing replay table dtype: {dtype}") - return dtype - - def update_host_cache(self, positions: paddle.Tensor, slot_mapping: paddle.Tensor): - """Update the host cache with new tokens""" - for batch_id, position in enumerate(positions): - if len(position) > 0 and len(slot_mapping[batch_id]) > 0: - routing_ids = self.routing_replay_table[batch_id, position, :, :].contiguous() - routing_ids = routing_ids.cpu() - - self._host_cache[slot_mapping[batch_id], :, :] = routing_ids - - def get_token_positions(self, seq_lens_decoder, seq_lens_this_time): - """Get token position of each sequence in a batch.""" - starts = seq_lens_decoder.numpy() - increase_num = seq_lens_this_time.numpy() - - positions = [] - for i in range(self.max_num_seqs): - if seq_lens_this_time[i] == 0: - positions.append([]) - continue - repeated_base = np.repeat(starts[i], increase_num[i]) - positions.append(repeated_base + np.arange(0, increase_num[i])) - - return positions - - def compute_slot_mapping(self, positions: np.ndarray): - """Compute the mapping between token ids and kvcache slots""" - slot_mapping = [] - for batch_id, position in enumerate(positions): - if len(position) == 0: - slot_mapping.append([]) - continue - block_table_indices = position // self.fd_config.cache_config.block_size - token_block_ids = self.block_table[batch_id, block_table_indices] - block_offset = position % self.fd_config.cache_config.block_size - - token_cache_ids = np.array(token_block_ids) * self.fd_config.cache_config.block_size + block_offset - slot_mapping.append(token_cache_ids) - - return slot_mapping - - def _get_routing_from_cache(self, finished_batch_ids, seq_lens_decoder): - """ - When request is finished or cleared the length of the request is recorded at seq_lens_decoder - 1. finish the step: after update input, lens = seq_lens_decoder_buffer - 2. clear parameter: after update input, lens = seq_lens_decoder_buffer - """ - # Get the slot mapping of the request cache. - current_token_nums = seq_lens_decoder.numpy() - positions = [] - for batch_id in range(self.max_num_seqs): - position = [] - if batch_id in finished_batch_ids: - position = np.arange(0, current_token_nums[batch_id]) - positions.append(position) - - # Collection the cached routing information - token_cache_ids = self.compute_slot_mapping(positions=positions) - for slot_map in token_cache_ids: - if len(slot_map) > 0: - token_cached_routing = self._host_cache[slot_map, :, :] - return paddle.transpose(token_cached_routing, [1, 0, 2]) - raise ValueError("No cached routing found") - - def put_finished_batch( - self, - finished_batch_ids, - seq_lens_decoder, - ): - finished_batch_ids_list = finished_batch_ids.cpu().tolist() - for batch_id, finished in enumerate(finished_batch_ids_list): - if finished: - assert batch_id in self.routing_batch_to_request.keys() - # Deregister the request - request_id = self._deregister_request(batch_id) - # Put the routing of finished request to store - self._put_request_to_store( - batch_id=batch_id, - request_id=request_id, - seq_lens_decoder=seq_lens_decoder, - ) - # Clear the slot of the finished batch - self._clear_table_slot(batch_id) - - def register_request(self, batch_id: int, request_id: str): - """ - Register a new request to routing replay table - Args: - batch_id: The batch ID of this request - request_id: The global ID of the request is usually executed by the training process in RL - """ - # The chunked prefill tasks will be registered repeatedly - if batch_id in self.routing_batch_to_request: - if self.routing_batch_to_request[batch_id] == request_id: - logger.warning(f"[R3] Request {request_id} has been registered at {batch_id}.") - return - else: - raise RuntimeError( - f"[R3] The Batch {batch_id} has been registered by request {self.routing_batch_to_request[batch_id]}, now robed by {request_id}," - ) - - # Register the new request - self.routing_batch_to_request[batch_id] = request_id - logger.info(f"[R3] Register request {request_id} with batch id {batch_id}") - - def _deregister_request(self, batch_id: int) -> str: - """ - Deregister a request from routing replay table - """ - assert batch_id in self.routing_batch_to_request - return self.routing_batch_to_request.pop(batch_id) - - def _put_request_to_store( - self, - batch_id: int, - request_id: str, - seq_lens_decoder, - ): - if self.tp_rank == 0: - before_put_request_time = time.perf_counter() - - # Collect the routing of finished request - batch_buffer = self._get_routing_from_cache( - finished_batch_ids=[batch_id], seq_lens_decoder=seq_lens_decoder - ) - rollout_id = self.split_request_id(request_id) - - if self.use_fused_put: - self._store_wrapper.submit_put_task(routing_indices=batch_buffer, rollout_id=rollout_id) - else: - for layer_id in range(self.num_moe_layers): - layer_buffer = batch_buffer[layer_id] - self._store_wrapper.submit_put_task( - routing_indices=layer_buffer, rollout_id=rollout_id, layer_idx=layer_id - ) - - # Only store the routing of last turn - if self.only_last_turn: - self._store_wrapper.submit_clear_prefix_batch_task(rollout_id=rollout_id) - - logger.info(f"[R3] Submit {request_id} time cost: {time.perf_counter() - before_put_request_time}") - - def clear_request(self, batch_id: int): - """Clear the routing indices of the request""" - self._clear_table_slot(batch_id) - self.routing_batch_to_request.pop(batch_id, None) - - def _clear_table_slot(self, batch_id: int): - assert 0 <= batch_id < self.max_num_seqs - self.routing_replay_table[batch_id].fill_(-1) - - def get_routing_table(self) -> paddle.Tensor: - return self.routing_replay_table - - def split_request_id(self, request_id: str): - """ - Split the request id to get rollout id. - - request_id: "chatcmpl-request.user-uuid" - rollout_id: "request.user" - example: "chatcmpl-xxx_xxx_epoch_15:2:2:1-d9f16c5c-65f6-4815-b44d-14e2c581907c_0" -> "xxx_xxx_epoch_15:2:2:1" - """ - chat_type, tmp_str = request_id.split("-", 1) - # NOTE(gongshaotian): only support chatcmpl now - assert ( - chat_type == "chatcmpl" - ), "Rollout Routing Replay only supports chatcmpl. Please check whether the request type and userid settings are correct." - reversed_tmp_str = tmp_str[::-1].split("-", 5) - rollout_id = reversed_tmp_str[-1][::-1] - return rollout_id - - def clear_all_request(self): - """Clear all requests""" - self.routing_replay_table.fill_(-1) - self.routing_batch_to_request = {} - - -class StoreWrapper(object): - def __init__(self, fd_config: False) -> None: - super().__init__() - self.fd_config = fd_config - - # Initialize task queue - moe_layer_num = fd_config.model_config.num_hidden_layers - fd_config.model_config.moe_layer_start_index - max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.queue_max_size = moe_layer_num * max_num_seqs * 1000 + # Small GPU transient buffer: only current step's token routing + # TODO(Chengyanfu): Use max_num_batched_tokens to replace get_max_chunk_tokens() + max_num_batched_tokens = self.fd_config.get_max_chunk_tokens() + shape = [max_num_batched_tokens, self.num_moe_layers, self.moe_top_k] - self.manager = multiprocessing.Manager() - self._task_queue = self.manager.Queue(maxsize=self.queue_max_size) + self.device_routing_buffer = paddle.full(shape=shape, fill_value=-1, dtype=dtype) + self.routing_staging_buf = paddle.full(shape=shape, fill_value=-1, dtype=dtype) + self.slot_mapping_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64) - self._monitor_thread: threading.Thread = None - self._stop_monitor = threading.Event() - - # Initialize consumer process - self._routing_store_process = StoreProcess( - task_queue=self._task_queue, - routing_replay_config=self.fd_config.routing_replay_config, - max_model_len=self.fd_config.model_config.max_model_len, - ) - self._sotre_process_running = False - - # Register atexit handler - atexit.register(self.shutdown) - - def shutdown(self): - """ """ - if not self._sotre_process_running: - return - self._sotre_process_running = False - - # Stop the monitor thread - self._stop_monitor.set() - if self._monitor_thread and self._monitor_thread.is_alive(): - self._monitor_thread.join(timeout=3.0) - - # Put a sentinel value to signal the consumer to stop - if self._routing_store_process and self._routing_store_process.is_alive(): - try: - self._task_queue.put_nowait(None) - except Exception as e: - logger.info(f"Could not put sentinel into queue: {e}") - - if self._routing_store_process and self._routing_store_process.is_alive(): - # Wait for all tasks to be processed - self._routing_store_process.join(timeout=10.0) - if self._routing_store_process.is_alive(): - self._routing_store_process.close() - self._routing_store_process.join() - - self._task_queue.join() - self.manager.shutdown() - self._sotre_process_running = False - - def start_store_warpper(self): - """ """ - if self._sotre_process_running: - return - self._sotre_process_running = True - - # Start monitor thread - self._stop_monitor.clear() - self._monitor_thread = threading.Thread(target=self._monitor_queue_load, daemon=True) - self._monitor_thread.start() - - # Start Routing Store Wrapper in sub process - self._routing_store_process.start() - - def _monitor_queue_load(self): - """ """ - while not self._stop_monitor.is_set(): - time.sleep(2.0) - if not self._sotre_process_running: - break - qsize = self._task_queue.qsize() - - # Alarm when the task exceeds 80% of the queue capacity - if qsize > self.queue_max_size * 0.8: - logger.warning( - f"[Monitor] Queue load is HIGH: {qsize}/{self.queue_max_size}. " - f"Dropped tasks so far: {self._dropped_tasks}. " - "Consider increasing max_workers or queue_max_size." - ) - logger.debug(f"[Monitor] Queue load: {qsize}/{self.queue_max_size}") + self.cpu_routing_buf = paddle.zeros(shape, dtype=dtype).pin_memory() + self.cpu_slot_mapping_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory() - def submit_put_task(self, routing_indices: paddle.Tensor, rollout_id: str, layer_idx: int = None) -> None: - """Submit a put task to the task queue""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - if layer_idx is not None: - rdma_rollout_key = f"{rollout_id}_{layer_idx}" + if self.debug_mode: + self.position_ids_staging_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64) + self.cpu_position_ids_buf = paddle.zeros([max_num_batched_tokens], dtype=paddle.int64).pin_memory() else: - rdma_rollout_key = rollout_id - - routing_indices_np = routing_indices.numpy() + self.position_ids_staging_buf = None + self.cpu_position_ids_buf = None - task: StoreTask = {"task_type": "put", "key": rdma_rollout_key, "data": routing_indices_np} - - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError(f"Queue is FULL. Dropping put task for key: {rdma_rollout_key}. ") - logger.info(f"[R3] Submit put task for key: {rdma_rollout_key}, cost time: {time.perf_counter()-start_time} s") + self._pending_save = None # {"num_tokens": int} - def submit_clear_store_task(self) -> None: - """Submit clear store task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - - start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_store", "key": None, "data": None} + # Lazy attach to SharedMemory routing_host_buffer (created by Engine after profiling) + self.routing_host_view = None + self._routing_host_view_attach_attempted = False + self._routing_host_view_shm_name = ( + f"routing_host_buffer.{str(self.fd_config.parallel_config.local_engine_worker_queue_port)}" + ) + self._routing_host_view_shape = (max_num_kv_tokens, self.num_moe_layers, self.moe_top_k) + self._routing_host_view_dtype = dtype - try: - self._task_queue.put_nowait(task) - # Wait for the task to be processed - self._task_queue.join() - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") - logger.info(f"[R3] Submit clear task, cost time: {time.perf_counter()-start_time} s") - - def submit_clear_prefix_batch_task(self, rollout_id) -> None: - """Submit clear prefix batch task""" - if not self._sotre_process_running: - raise RuntimeError("Store not started.") - prefix_batch = self.get_needed_clear_ids(rollout_id) - - if prefix_batch is None: - return - start_time = time.perf_counter() - task: StoreTask = {"task_type": "clear_prefix_batch", "key": prefix_batch, "data": None} - try: - self._task_queue.put_nowait(task) - except Exception: - raise RuntimeError("Queue is FULL. Dropping put task for key: clear_store. ") + gpu_buffer_bytes = int(np.prod(self.device_routing_buffer.shape)) * np.dtype(dtype).itemsize logger.info( - f"[R3] Submit clear prefix batch task for key: {prefix_batch}, cost time: {time.perf_counter()-start_time} s" + f"[R3] GPU transient routing buffer: {self.device_routing_buffer.shape} " + f"({gpu_buffer_bytes / 1024:.1f} KB)" ) - def get_needed_clear_ids(self, roullout_id: str) -> Optional[str]: - """ - Generate the prefix IDs for all closed multi-round tasks. - rollout_id: "xxx_xxx_epoch_15:2:2:1" - example: xxx_xxx_data_id:gen_id:turn_id:segment_id - """ - reversed_segment_id, reversed_turn_id, reversed_prefix_gen_id = roullout_id[::-1].split(":", 2) - prefix_gen_id = reversed_prefix_gen_id[::-1] - turn_id = eval(reversed_turn_id[::-1]) - segment_id = eval(reversed_segment_id[::-1]) - - assert turn_id >= 0 and segment_id >= 0 - prefix_batch = None - if turn_id > 0: - prefix_batch = f"{prefix_gen_id}:{(turn_id-1)}:{segment_id}" - return prefix_batch - - -class StoreTask(TypedDict): - task_type: str - key: str - data: np.ndarray - - -class StoreProcess(Process): - def __init__(self, task_queue: Queue, routing_replay_config: RoutingReplayConfig, max_model_len: int) -> None: - super().__init__() - self.max_model_len = max_model_len - self._task_queue = task_queue - self.routing_replay_config = routing_replay_config - self.max_workers = 5 - self._closed = False - - # Note: _routing_store and _event_loop_thread must be initialized in run() - # because they cannot be properly inherited after fork() - self._routing_store = None - self._event_loop_thread = None - - def run(self): - logger.info(f"[R3] Start Running Store Wrapper in sub process {os.getpid()}") - - # Initialize routing store in subprocess - self._routing_store = get_routing_store(routing_replay_config=self.routing_replay_config) - - # Initialize event loop thread in subprocess - self._event_loop_thread = AsyncEventLoopThread() - self._event_loop_thread.start() - if not self._event_loop_thread._started_event.wait(timeout=5.0): - raise RuntimeError("Failed to start async event loop thread in subprocess") - - clear_store_task = StoreTask({"task_type": "clear_store", "key": None, "data": None}) - self._task_queue.put_nowait(clear_store_task) - - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - while not self._closed: - try: - task = self._task_queue.get() - if task is None: # Sentinel - self._task_queue.task_done() - break - - if task["task_type"] == "put": - future = executor.submit(self.process_put_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_store": - future = executor.submit(self.process_clear_store_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - elif task["task_type"] == "clear_prefix_batch": - future = executor.submit(self.process_clear_prefix_batch_task, task) - future.add_done_callback(lambda f: self._task_queue.task_done()) - except Exception as e: - self._task_queue.task_done() - raise RuntimeError(f"Error during processing task. {e}") - - logger.info(f"[Consumer Process {Process.current_process().pid}] Shutdown.") - - def process_put_task(self, store_task: StoreTask) -> None: - try: - # TODO(gongshaotian): delete this after trainer support dynamic len - store_task["data"] = self.pad_routing_indices(store_task["data"]) - coro_obj = self._routing_store.put(routing_key=store_task["key"], routing_indices=store_task["data"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting put task: {e}") - traceback.print_exc() - raise - - def process_clear_store_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_store() - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error during processing clear store task. {e}") - traceback.print_exc() - raise - - def process_clear_prefix_batch_task(self, store_task: StoreTask) -> None: - try: - coro_obj = self._routing_store.clear_prefix_batch(routing_prefix_key=store_task["key"]) - future = self._event_loop_thread.submit_coroutine( - coro_obj, callback=functools.partial(self._on_async_task_completed, store_task) - ) - return future - except Exception as e: - logger.error(f"Error submitting clear_prefix_batch task: {e}") - traceback.print_exc() - raise - - def _on_async_task_completed(self, task, future): - """ """ + def _try_attach_routing_host_view(self): + """Lazily attach to SharedMemory routing_host_buffer on first use.""" + if self._routing_host_view_attach_attempted: + return + self._routing_host_view_attach_attempted = True try: - # result = future.result() - logger.info(f"[R3] Async task completed: {task['task_type']}, key: {task['key']}") - except Exception as e: - logger.error(f"[R3] Async task failed: {task['task_type']}, key: {task['key']}, error: {e}") - traceback.print_exc() - raise - - def close(self): - """Close the store process""" - self._closed = True - if hasattr(self, "_event_loop_thread"): - self._event_loop_thread.stop() - - def pad_routing_indices(self, routing_indices: np.ndarray) -> np.ndarray: - """Pad routing indices of the request levevl to max model len""" - routing_shape = routing_indices.shape - if len(routing_shape) == 2: # [token, topk] - pad_array = np.full( - shape=[(self.max_model_len - routing_indices.shape[0]), routing_indices.shape[1]], - fill_value=-1, - dtype=routing_indices.dtype, + self.routing_host_view = RoutingHostBufferView( + shape=self._routing_host_view_shape, + dtype=self._routing_host_view_dtype, + shm_name=self._routing_host_view_shm_name, ) - return np.concatenate([routing_indices, pad_array], axis=0) - - elif len(routing_shape) == 3: # [layer, token, topk] - pad_array = np.full( - shape=[ - routing_indices.shape[0], - (self.max_model_len - routing_indices.shape[1]), - routing_indices.shape[2], - ], - fill_value=-1, - dtype=routing_indices.dtype, + logger.info(f"[R3] Attached to RoutingHostBuffer SharedMemory: {self._routing_host_view_shm_name}") + except FileNotFoundError: + logger.warning( + f"[R3] RoutingHostBuffer SharedMemory {self._routing_host_view_shm_name} not found. " + "Routing capture will be skipped." ) - return np.concatenate([routing_indices, pad_array], axis=1) - else: - raise ValueError(f"Invalid routing indices shape: {routing_shape}") - -class AsyncEventLoopThread(threading.Thread): - def __init__(self): - super().__init__(daemon=True) - self._loop = None - self._started_event = threading.Event() - self._closed = False - - def run(self): - """Run the async event loop""" - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - # Set the event loop to be started - self._started_event.set() - logger.info("[EventLoopThread] Event loop started, running forever...") - - try: - self._loop.run_forever() - logger.info("[EventLoopThread] Event loop stopped") - except Exception as e: - logger.error(f"[EventLoopThread] Event loop exception: {e}") - traceback.print_exc() - finally: - logger.info("[EventLoopThread] Closing event loop") - self._loop.close() - - def submit_coroutine(self, coro, callback=None): - """Thread safely submit coroutine to event loop""" - if self._closed: - raise RuntimeError("Event loop thread is closed") - if not self._started_event.wait(timeout=5.0): - raise RuntimeError("Event loop failed to start within 5 seconds") - - future = asyncio.run_coroutine_threadsafe(coro, self._loop) - - if callback: - - def wrapped_callback(f): - try: - callback(f) - except Exception as e: - logger.error(f"Error in callback: {e}") - traceback.print_exc() - - future.add_done_callback(wrapped_callback) - return future - - def stop(self): - """Stop the event loop""" - if not self._closed: - self._closed = True - if self._loop: - self._loop.call_soon_threadsafe(self._loop.stop) - - -class RoutingStoreBase(ABC): - """Base class for routing store""" - - def __init__(self, routing_replay_config: RoutingReplayConfig) -> None: - self.routing_replay_config = routing_replay_config - - @abstractmethod - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - raise NotImplementedError - - @abstractmethod - async def clear_store( - self, + def prepare_pending_save( + self, num_tokens: int, slot_mapping_gpu: paddle.Tensor, position_ids_gpu: paddle.Tensor = None ): - """Clear the routing indices store""" - raise NotImplementedError - - @abstractmethod - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError - - -class RoutingStoreLocal(RoutingStoreBase): - """Routing Store using local memory""" - - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - self.local_store_dir = routing_replay_config.local_store_dir - os.makedirs(self.local_store_dir, exist_ok=True) - - async def put( - self, - routing_key: str, - routing_indices: np.ndarray, - ) -> None: - """Put the routing indices into store""" - # TODO(gongshaotian) covert ./store_dir/routing_key/layer_id.pdtensor to ./store_dir/routing_key.pdtensor - time_before_put = time.perf_counter() - - if len(routing_indices.shape) == 2: - re_layer_id, re_rollout_id = routing_key[::-1].split("_", 1) - rollout_id = re_rollout_id[::-1] - layer_id = re_layer_id[::-1] - request_path = os.path.join(self.local_store_dir, rollout_id) - file_path = os.path.join(request_path, f"layer_{layer_id}.pdtensor") - elif len(routing_indices.shape) == 3: - request_path = os.path.join(self.local_store_dir, routing_key) - file_path = os.path.join(request_path, f"{routing_key}.pdtensor") + """ + Enqueue D2D + async D2H for routing data and slot_mapping. + Must be called before post_process_event.record(). + All ops are enqueued on the current CUDA stream; CPU returns immediately. + + 1. D2D (non-blocking): device_routing_buffer → routing_staging_buf + 2. D2D (non-blocking): slot_mapping_gpu → slot_mapping_staging_buf + 3. async D2H: routing_staging_buf → cpu_routing_buf + 4. async D2H: slot_mapping_staging_buf → cpu_slot_mapping_buf + 5. async D2H (debug mode): position_ids_gpu → cpu_position_ids_buf + """ + if num_tokens > 0: + if self.fd_config.scheduler_config.enable_overlap_schedule: + num_tokens = self.token_num_overlap + slot_mapping_gpu = slot_mapping_gpu[:num_tokens] + if position_ids_gpu is not None: + position_ids_gpu = position_ids_gpu[:num_tokens] + + # D2D: GPU → staging + self.routing_staging_buf.copy_(self.device_routing_buffer, False) + self.slot_mapping_staging_buf.copy_(slot_mapping_gpu, False) + # Async D2H: staging → CPU pinned + self.cpu_routing_buf.copy_(self.routing_staging_buf, False) + self.cpu_slot_mapping_buf.copy_(self.slot_mapping_staging_buf, False) + + if self.debug_mode and position_ids_gpu is not None and self.cpu_position_ids_buf is not None: + self.position_ids_staging_buf.copy_(position_ids_gpu, False) + self.cpu_position_ids_buf.copy_(self.position_ids_staging_buf, False) + + self._pending_save = {"num_tokens": num_tokens} else: - raise ValueError(f"Invalid routing indices shape: {routing_indices.shape}") + self._pending_save = None - paddle.save(routing_indices, file_path) - logger.info(f"[R3] The routing key {routing_key} put cost is {time.perf_counter()-time_before_put}s") + def flush_pending_save(self): + """ + Pure CPU operation. Called after post_process_event.synchronize(), + which guarantees all D2D and D2H transfers have completed. + Scatter from CPU pinned buffers to SharedMemory. + """ + pending = self._pending_save + if pending is None: + return + self._pending_save = None - async def clear_store(self): - """Clear the routing indices store""" - if os.path.isdir(self.local_store_dir): - shutil.rmtree(self.local_store_dir) + if self.routing_host_view is None: + if not self._routing_host_view_attach_attempted: + self._try_attach_routing_host_view() + if self.routing_host_view is None: + return + + num_tokens = pending["num_tokens"] + # NOTE(gongshaotian): Slice pinned memory tensor maybe cause problem. + data = self.cpu_routing_buf.cpu()[:num_tokens].numpy() + slot_cpu = self.cpu_slot_mapping_buf.cpu() + slot_cpu_slice = slot_cpu[:num_tokens] + slot_np = slot_cpu_slice.numpy() + + if self.debug_mode and self.cpu_position_ids_buf is not None: + position_ids = self.cpu_position_ids_buf.cpu()[:num_tokens].numpy() + expected_routing = position_ids[:, None, None] + expected_routing = np.broadcast_to(expected_routing, (num_tokens, self.num_moe_layers, self.moe_top_k)) + if not np.array_equal(data, expected_routing): + # 1. Check routing capture + mismatch_mask = (data != expected_routing).any(axis=(1, 2)) + mismatched_token_indices = np.where(mismatch_mask)[0] + logger.error( + f"[R3 Debug] flush mismatch! num_tokens={num_tokens}, mismatched_tokens={len(mismatched_token_indices)}" + ) + logger.error(f"Mismatched token indices: {mismatched_token_indices}") + for idx in mismatched_token_indices: + logger.error( + f" token={idx}, position_id={position_ids[idx]}, slot={slot_np[idx]}, " + f"expected={expected_routing[idx, :, :]}, actual={data[idx, :, :]}" + ) + raise ValueError("Routing data verification failed.") + else: + # 2. Check slot mapping generation and validate slot indices (should be >= 0) + if slot_cpu_slice.min() < 0: + error_parts = [f"[R3 Debug] Invalid slot indices: num_tokens={num_tokens}"] + error_parts.append(" token |slot_staging | slot_pinned | slot_cpu | position_id | data[0,0]") + error_parts.append(" " + "-" * 50) + for i in range(num_tokens): + error_parts.append( + f" {i:4d} | {int(self.slot_mapping_staging_buf[i]):7d} | {int(self.cpu_slot_mapping_buf[i]):7d} | {int(slot_cpu[i]):7d} | {int(position_ids[i]):11d} | {int(data[i, 0, 0])}" + ) + raise AssertionError("\n".join(error_parts)) + # 2.1 Check slot range (should be < max_num_kv_tokens) + max_slot = slot_cpu_slice.max() + if max_slot >= self.max_num_kv_tokens: + invalid_slots = np.where(slot_np >= self.max_num_kv_tokens)[0] + error_parts = [ + f"[R3 Debug] Slot indices out of range: num_tokens={num_tokens}, " + f"max_slot={max_slot}, max_num_kv_tokens={self.max_num_kv_tokens}" + ] + error_parts.append(f" Invalid slot indices: {invalid_slots[:10]}... ({len(invalid_slots)} total)") + error_parts.append(" token |slot | position_id | data[0,0]") + error_parts.append(" " + "-" * 50) + for idx in invalid_slots[:10]: + error_parts.append( + f" {idx:4d} | {int(slot_np[idx]):6d} | {int(position_ids[idx]):11d} | {int(data[idx, 0, 0])}" + ) + raise AssertionError("\n".join(error_parts)) + # 3. Check slot mapping duplicates + unique_slots, counts = np.unique(slot_np, return_counts=True) + num_unique = len(unique_slots) + num_duplicates = np.sum(counts > 1) + if num_duplicates > 0: + duplicate_indices = np.where(counts > 1)[0] + dup_slots_info = [] + for slot_idx in duplicate_indices[:5]: + slot = unique_slots[slot_idx] + count = counts[slot_idx] + dup_token_indices = np.where(slot_np == slot)[0] + dup_slots_info.append(f"slot={slot} count={count} indices={dup_token_indices}") + logger.error( + f"[R3 Debug] flush validation passed but found duplicate slots! " + f"num_tokens={num_tokens}, unique_slots={num_unique}, duplicates={num_duplicates}. " + f"Details: {'; '.join(dup_slots_info)}" + ) + else: + logger.debug( + f"[R3 Debug] flush validation passed: num_tokens={num_tokens}, " + f"slots=[{slot_np[0]}...{slot_np[-1]}], unique_slots={num_unique}" + ) - logger.info("[R3] Clear routing store.") + self.routing_host_view.scatter(slot_np, data) - async def clear_prefix_batch(self, routing_prefix_key: str): - """Clear the routing indices""" - raise NotImplementedError + def get_device_routing_buffer(self) -> paddle.Tensor: + return self.device_routing_buffer + def clear(self): + """Clear GPU buffer and pending save state. Used during RL round cleanup.""" + self.device_routing_buffer.fill_(-1) + self._pending_save = None -class RoutingStoreRDMA(RoutingStoreBase): - """Routing Store using RDMA""" - def __init__(self, routing_replay_config) -> None: - super().__init__(routing_replay_config=routing_replay_config) - try: - # Only used in RLHF - from p2pstore import P2PClient, P2PConfig - except ModuleNotFoundError: - raise ModuleNotFoundError(" RoutingStoreRDMA and p2pstore only support in RLHF. ") - - rdma_store_server = routing_replay_config.rdma_store_server - p2pConfig = P2PConfig(metadata_server=rdma_store_server) - self.p2p_client = P2PClient(p2pConfig) - - async def put(self, routing_key: str, routing_indices: np.ndarray) -> None: - """Put the routing indices into store""" - time_before_put = time.perf_counter() - result = await self.p2p_client.put(routing_key, routing_indices) - logger.info(f"[R3] The routing key {routing_key}, put cost is {time.perf_counter()-time_before_put}s") - return result - - async def clear_prefix_batch(self, routing_prefix_key: str): - time_before_clear = time.perf_counter() - result = await self.p2p_client.delete_prefix_batch([routing_prefix_key]) - logger.info( - f"[R3] The clear routing prefix key {routing_prefix_key}, cost is {time.perf_counter()-time_before_clear}s" - ) - return result - - async def clear_store(self): - """Clear the routing indices store""" - time_before_clear = time.perf_counter() - result = await self.p2p_client.clear() - logger.info(f"[R3] Clear routing store cost is {time.perf_counter()-time_before_clear}s.") - return result - - -def get_routing_store(routing_replay_config: RoutingReplayConfig) -> RoutingStoreBase: - if routing_replay_config.routing_store_type == "local": - return RoutingStoreLocal(routing_replay_config=routing_replay_config) - elif routing_replay_config.routing_store_type == "rdma": - return RoutingStoreRDMA(routing_replay_config=routing_replay_config) - else: - raise ValueError( - f"Invalid routing store type: '{routing_replay_config.routing_store_type}'. " - "Valid types are: 'local', 'rdma'" - ) +# Backward compatibility alias +RoutingReplayManager = RoutedExpertsCapturer diff --git a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py index ac5dfa96fcc..ac9e18480b6 100644 --- a/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/moe/triton_moe_kernels.py @@ -198,3 +198,142 @@ def fused_moe_kernel_paddle( c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + + +# --------------------------------------------------------------------------- +# BF16-native MoE kernel, ported from vLLM fused_moe_kernel (BF16-only path). +# +# Key differences from fused_moe_kernel_paddle (the wint8/fp8 kernel above): +# 1. compute_type is a tl.constexpr parameter (not hardcoded bfloat16). +# 2. offs_token is cast to int64 to prevent stride-multiplication overflow. +# 3. b matrix load always uses a K-boundary mask (no even_Ks special path). +# 4. Router-weight multiplication is done in fp32 before the final cast. +# 5. No quantization paths (use_fp8/int8 removed for clarity). +# --------------------------------------------------------------------------- +@enable_compat_on_triton_kernel +@triton.jit +def fused_moe_kernel_bf16( # pragma: no cover -- Triton JIT; body compiles to GPU code + # Pointers + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Dimensions (runtime scalars) + N, + K, + EM, + num_valid_tokens, + # Strides + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters (compile-time constants) + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + # naive_block_assignment: tl.constexpr = False, + even_Ks: tl.constexpr = False, +): + """ + BF16 Fused-MoE GEMM kernel, ported from vLLM. + + A: [num_tokens, K] – input activations (bf16) + B: [E, K, N] – expert weights (bf16) + C: [num_tokens * top_k, N] – output (bf16) + + sorted_token_ids: [EM] flat token-expert pair indices (int32) + expert_ids: [EM // BLOCK_SIZE_M] expert index per M-block (int32) + + When naive_block_assignment=True, each M-block processes exactly one + token-expert pair (skipping the preprocess/sort step). In this mode: + - expert_ids[pid_m] holds the expert index for token-expert pair pid_m + - sorted_token_ids_ptr is unused + - offs_token is constructed as [pid_m, invalid, invalid, ...] + This avoids the preprocess kernel overhead for very small token counts. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs = tl.arange(0, BLOCK_SIZE_M) + + offs_token_id = pid_m * BLOCK_SIZE_M + offs + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + + # if not naive_block_assignment: + # offs_token_id = pid_m * BLOCK_SIZE_M + offs + # offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + # else: + # # Each block handles exactly one token-expert pair: + # # row 0 = pid_m (the token-expert pair index), remaining rows are + # # set to num_valid_tokens which will fail the < mask check. + # offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + + # Cast to int64 to prevent overflow: stride_cm * offs_token can exceed int32 + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # A pointer: a_ptr[token_idx, :K] where token_idx = offs_token // top_k + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + + # B pointer: b_ptr[expert, :K, offs_bn] — B layout is [E, K, N] + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if even_Ks: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Router-weight multiplication in fp32 (before precision conversion) + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index 14e248e0a72..8efe0056eb5 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -35,6 +35,7 @@ is_batch_invariant_mode_enabled, rms_norm_batch_invariant, ) +from .flashinfer_comm_fusion import flashinfer_allreduce_residual_rmsnorm from .utils import get_tensor, modules_to_convert @@ -122,6 +123,11 @@ def __init__( self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank self.tp_group = self.fd_config.parallel_config.tp_group is_input_norm = prefix.endswith(".input_layernorm") + self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and ( + ("post_attention_layernorm" in prefix) + or (("input_layernorm" in prefix and layer_id != 0) and not fd_config.parallel_config.use_ep) + ) + self.is_last_norm = prefix.endswith(".norm") self.split_x = ( self.fd_config.parallel_config.use_sequence_parallel_moe @@ -234,12 +240,25 @@ def forward( if residual_input is None: residual_out = x + use_allreduce_fused = ( + self.enable_all_reduce_fusion + and self.tp_size > 1 + and x.shape[0] <= 2048 + and residual_input is not None + and current_platform.is_cuda() + ) if proxy_rmsnorm is None: if current_platform.is_gcu(): if residual_input is None: norm_out = rms_norm(x, self.weight, self.eps) return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) + # enable trtllm all reduce fusion + elif use_allreduce_fused: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps + ) + assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" else: if is_batch_invariant_mode_enabled(): # M-invariant path: per-row Triton kernel, no cross-row reduction @@ -261,9 +280,19 @@ def forward( quant_min_bound=self.quant_min_bound, ) else: - if residual_input is not None: - x = x + residual_input - norm_out = proxy_rmsnorm(x, self.weight, self.eps), x + if use_allreduce_fused: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=x, + residual=residual_input, + weight=self.weight, + eps=self.eps, + ) + assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" + else: + if residual_input is not None: + x = x + residual_input + norm_out = proxy_rmsnorm(x, self.weight, self.eps), x out = norm_out[0].astype(x_dtype) if residual_input is not None: @@ -341,7 +370,7 @@ def forward( forward_meta, proxy_rmsnorm=None, ) -> paddle.Tensor: - if proxy_rmsnorm is None and self.qk_norm_fused and forward_meta.step_use_cudagraph: + if proxy_rmsnorm is None and self.qk_norm_fused: qkv_out = qk_rmsnorm_fused( qkv_out, self.q_norm.weight, diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index 3e9e34c54ab..2c9992b18b5 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -54,17 +54,56 @@ def _compute_hadamard_block_size(moe_intermediate_size: int, tp_size: int) -> in return block_size +def _is_full_quantization_config(quantization_dict): + """ + Determine whether the parsed quantization dict is a simple method name or a full quantization_config. + Simple method name: {"quantization": "wint4"} (only one key "quantization") + Full config: {"quantization": "mix_quant", "dense_quant_type": "wint8", ...} (multiple keys) + Or torch format: {"quant_method": "fp8", "weight_block_size": [128, 128]} (has "quant_method" key) + """ + if "quant_method" in quantization_dict: + return True + if len(quantization_dict) > 1: + return True + return False + + def parse_quant_config(args, model_config, is_ernie, is_v1_loader): if args.quantization is not None and isinstance(args.quantization, str): args.quantization = parse_quantization(args.quantization) + + # Determine whether CLI --quantization is a simple method name or a full JSON quantization_config + cli_quantization = args.quantization + cli_is_full_config = ( + cli_quantization is not None + and isinstance(cli_quantization, dict) + and _is_full_quantization_config(cli_quantization) + ) + + model_quantization_config = model_config.quantization_config + quantization_config = model_quantization_config + + # If CLI provides a full quantization_config JSON, handle priority with config.json + if cli_is_full_config: + if model_quantization_config is not None: + if model_quantization_config != cli_quantization: + logger.warning( + "The quantization_config from --quantization argument " + "differs from the one in model's config.json. " + "Using config.json's quantization_config as it has higher priority. " + f"config.json: {model_quantization_config}, " + f"--quantization: {cli_quantization}" + ) + else: + # config.json has no quantization_config, use CLI's full config + quantization_config = cli_quantization + # 1.model_config.is_quantized # TODO(bukejiyu) model_config.is_quantized is v0 only need to be removed in future if model_config.model_format == "torch": - quantization_config = model_config.quantization_config if quantization_config is not None: model_config.is_quantized = True else: - quantization_config = model_config.quantization_config if not model_config.is_quantized: if quantization_config is not None: if "is_quantized" in quantization_config: @@ -84,11 +123,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quant_config_name = None - if quantization_config is not None: + if model_quantization_config is not None: quant_config_name = _get_offline_quant_config_name( - quantization_config, model_config.model_format == "torch", is_v1_loader + model_quantization_config, model_config.model_format == "torch", is_v1_loader ) - elif args.quantization is not None: + elif cli_quantization is not None and not cli_is_full_config: quantization_config = {} try: quantization_config.update(args.quantization) @@ -116,8 +155,11 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quantization_config["hadamard_block_size"] = 512 quantization_config["quantization"] = "mix_quant" quant_config_name = "mix_quant" + elif cli_quantization is not None and cli_is_full_config: + quant_config_name = quantization_config["quantization"] else: quant_config_name = None + if quant_config_name is None: quant_config = None else: @@ -127,6 +169,7 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): quantization_config["is_quantized"] = True quant_cls = get_quantization_config(quant_config_name) quant_config = quant_cls.from_config(quantization_config) + return quant_config diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index a86170e0727..97f72e026fc 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -51,6 +51,19 @@ else: fp8_gemm_nt = None +# Detect whether fp8_gemm_nt accepts a 'bias' keyword argument +_fp8_gemm_nt_has_bias_kwarg = False +if fp8_gemm_nt is not None: + import inspect + + try: + _sig = inspect.signature(fp8_gemm_nt) + _fp8_gemm_nt_has_bias_kwarg = "bias" in _sig.parameters + except (ValueError, TypeError): + # pybind11 functions may not expose signatures via inspect; + # fall back to a cheap probe call to determine support. + pass + class BlockWiseFP8Config(QuantConfigBase): """ @@ -128,14 +141,22 @@ def deep_gemm_fp8_gemm_nt( sm_version = get_sm_version() if sm_version >= 100 and current_platform.is_cuda(): # disable_ue8m0_cast is default False for SM100 - fp8_gemm_nt( - (x, x_scale_tensor), - (layer_weight, layer_weight_scale_inv), - linear_out, - bias=bias, - ) + if _fp8_gemm_nt_has_bias_kwarg: + fp8_gemm_nt( + (x, x_scale_tensor), + (layer_weight, layer_weight_scale_inv), + linear_out, + bias=bias, + ) + else: + fp8_gemm_nt( + (x, x_scale_tensor), + (layer_weight, layer_weight_scale_inv), + linear_out, + ) + if bias is not None: + linear_out.add_(bias) else: - # disable_ue8m0_cast is default False for SM100 fp8_gemm_nt( (x, x_scale_tensor), (layer_weight, layer_weight_scale_inv), @@ -343,13 +364,13 @@ def apply(self, layer, x): else: x, x_scale_tensor = paddle.incubate.nn.functional.fp8_quant_blockwise( x, - using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0, + using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0 or fastdeploy.envs.FD_FP8_QUANT_WITH_POW2SCALE, output_scale_transpose=True, using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0, ) x_scale_tensor = x_scale_tensor.T[: x.shape[0], ...] - if get_sm_version() == 100 and current_platform.is_cuda(): + if get_sm_version() >= 100 and current_platform.is_cuda(): deep_gemm_fp8_gemm_nt( x, x_scale_tensor, @@ -370,5 +391,4 @@ def apply(self, layer, x): ) if layer.with_bias: linear_out = paddle.add(linear_out, layer.bias) - return linear_out diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index a5cd230f601..89b9467ecc6 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -14,46 +14,20 @@ # limitations under the License. """ -import importlib - import paddle import triton from paddleformers.utils.log import logger +from fastdeploy.model_executor.layers.utils import get_sm_version from fastdeploy.model_executor.ops.triton_ops import _per_token_group_quant_fp8 +from fastdeploy.model_executor.utils import try_import from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import per_token_group_fp8_quant -from ..utils import get_sm_version - - -def try_import(modules, name=None, fail_msg=None): - """ - try_import - """ - if not isinstance(modules, (list, tuple)): - modules = [modules] - - for m in modules: - assert isinstance(m, str), m - try: - m = importlib.import_module(m) - except ImportError: - m = None - - if m is not None: - if name is None: - return m - elif hasattr(m, name): - return getattr(m, name) - - if fail_msg is not None: - logger.warning(fail_msg) - -paddlefleet_ops = try_import(["paddlefleet.ops"]) +paddlefleet_ops = try_import(["paddlefleet.ops", "paddlefleet_ops"]) def load_deep_gemm(): @@ -67,11 +41,14 @@ def load_deep_gemm(): if current_platform.is_cuda(): if get_sm_version() >= 100: # SM100 should use PFCC DeepGemm - paddle.compat.enable_torch_proxy(scope={"deep_gemm"}) + paddle.enable_compat(scope={"deep_gemm"}) try: import logging - import paddlefleet.ops.deep_gemm as deep_gemm + try: + import paddlefleet.ops.deep_gemm as deep_gemm + except: + import paddlefleet_ops.deep_gemm as deep_gemm logging.getLogger().handlers.clear() logger.info("Detected sm100, use PaddleFleet DeepGEMM") diff --git a/fastdeploy/model_executor/layers/quantization/mxfp4.py b/fastdeploy/model_executor/layers/quantization/mxfp4.py index 9fa02866210..24ec38e696c 100644 --- a/fastdeploy/model_executor/layers/quantization/mxfp4.py +++ b/fastdeploy/model_executor/layers/quantization/mxfp4.py @@ -14,8 +14,6 @@ # limitations under the License. """ -import importlib -import importlib.util import math from enum import Enum from typing import Callable, Optional @@ -25,17 +23,18 @@ from fastdeploy import envs from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase -from fastdeploy.model_executor.utils import set_weight_attrs +from fastdeploy.model_executor.utils import has_flashinfer, set_weight_attrs from fastdeploy.platforms import current_platform if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch + from fastdeploy.utils import get_logger from ..moe import FusedMoE from .quant_base import QuantConfigBase, QuantMethodBase -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) logger = get_logger("config", "config.log") @@ -59,10 +58,6 @@ def check_device_capability(num): return False -def has_flashinfer(): - return importlib.util.find_spec("flashinfer") is not None - - def round_up(a, b): return ((a + b - 1) // b) * b diff --git a/fastdeploy/model_executor/layers/quantization/nvfp4.py b/fastdeploy/model_executor/layers/quantization/nvfp4.py index 196f6af6755..66e15750b75 100644 --- a/fastdeploy/model_executor/layers/quantization/nvfp4.py +++ b/fastdeploy/model_executor/layers/quantization/nvfp4.py @@ -33,7 +33,7 @@ from .quant_base import QuantConfigBase, QuantMethodBase -paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +paddle.enable_compat(scope={"flashinfer"}) from fastdeploy.platforms import current_platform diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index af7203ed6f1..6c7286e2606 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -20,6 +20,7 @@ import paddle from paddle import nn +from fastdeploy import envs from fastdeploy.config import ModelConfig from fastdeploy.platforms import current_platform @@ -87,8 +88,13 @@ def __init__(self, rotary_dim, base, partial_rotary_factor): def __call__(self, position_ids): bsz, max_seq_len = position_ids.shape[:2] - inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) - freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + if envs.FD_ENABLE_RL == 1: + idx = paddle.arange(0, self.rotary_dim, 2, dtype=paddle.int64).astype(paddle.float32) + inv_freq = 1.0 / (self.base ** (idx / self.rotary_dim)) + freqs = paddle.outer(position_ids.astype(inv_freq.dtype), inv_freq) + else: + inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) # shape: [B, S, D/2] rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32") emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2)) @@ -257,7 +263,7 @@ def forward( return query, key -class GptOssScalingRotaryEmbedding: +class YarnScalingRotaryEmbedding: def __init__( self, rotary_dim, @@ -334,10 +340,29 @@ def get_rope_impl( rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) elif architecture.startswith("Glm"): - rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor) + rope_scaling = getattr(model_config, "rope_scaling", None) + if ( + rope_scaling is not None + and isinstance(rope_scaling, dict) + and rope_scaling.get("rope_type", rope_scaling.get("type", "")) == "yarn" + and "factor" in rope_scaling + ): + yarn_rotary_dim = int(rotary_dim * partial_rotary_factor) if partial_rotary_factor < 1.0 else rotary_dim + rotary_emb_layer = YarnScalingRotaryEmbedding( + rotary_dim=yarn_rotary_dim, + base=base, + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + scale=rope_scaling["factor"], + mscale=rope_scaling.get("mscale", 1.0), + beta_fast=rope_scaling.get("beta_fast", 32), + beta_slow=rope_scaling.get("beta_slow", 1), + use_neox_rotary_style=False, + ) + else: + rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor) rotary_emb = rotary_emb_layer(position_ids) elif architecture.startswith("GptOss"): - rotary_emb_layer = GptOssScalingRotaryEmbedding( + rotary_emb_layer = YarnScalingRotaryEmbedding( rotary_dim=model_config.head_dim, base=model_config.rope_theta, original_max_position_embeddings=model_config.rope_scaling["original_max_position_embeddings"], diff --git a/fastdeploy/model_executor/layers/sample/logprobs.py b/fastdeploy/model_executor/layers/sample/logprobs.py index 559abdb298e..9f14c1d5e0e 100644 --- a/fastdeploy/model_executor/layers/sample/logprobs.py +++ b/fastdeploy/model_executor/layers/sample/logprobs.py @@ -123,7 +123,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices.cpu(), top_logprobs.cpu(), token_ranks.cpu()) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def build_output_logprobs( @@ -133,6 +144,7 @@ def build_output_logprobs( is_naive: bool = False, logprobs_mode: str = "default", compute_logprobs_fn: Optional[Callable] = None, + real_bsz: int = 0, ) -> Tuple[Optional[LogprobsTensors], Optional[paddle.Tensor]]: """ Build logprobs output for both NAIVE and speculative (MTP/Ngram) modes. @@ -153,69 +165,81 @@ def build_output_logprobs( scaling and top_p normalization. Used when logprobs_mode == "raw_logprobs". Returns: - tuple: (logprobs_tensors, cu_batch_token_offset) + tuple: (logprobs_tensors, cu_batch_token_offset, output_logits) """ num_logprobs = sampling_metadata.max_num_logprobs logprobs_tensors = None - cu_batch_token_offset = None - - if num_logprobs is None: - return logprobs_tensors, cu_batch_token_offset - real_bsz = share_inputs["seq_lens_this_time"].shape[0] + max_draft_token_num = share_inputs["accept_tokens"].shape[1] + max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] if is_naive: # NAIVE mode: one token per request, logits are already correct output_logits = logits - token_ids = share_inputs["accept_tokens"][:real_bsz, 0] + token_ids = share_inputs["accept_tokens"][:max_occupied_slots, 0] else: # Speculative mode: extract target logits for accepted positions from fastdeploy.model_executor.layers.sample.ops import ( - speculate_get_target_logits, - ) - - batch_token_num = paddle.where( - share_inputs["seq_lens_encoder"][:real_bsz] != 0, - paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), - share_inputs["seq_lens_this_time"], - ).flatten() - - share_inputs["batch_token_num"] = batch_token_num - - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" + speculate_get_accept_tokens_and_logits, ) - cu_batch_token_offset = paddle.concat( - [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] - ).astype("int32") - share_inputs["cu_batch_token_offset"] = cu_batch_token_offset output_logits = paddle.empty( - [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], + [real_bsz * max_draft_token_num, logits.shape[1]], dtype=logits.dtype, ) - speculate_get_target_logits( + token_ids = paddle.full([real_bsz * max_draft_token_num], fill_value=0, dtype="int64") + + speculate_get_accept_tokens_and_logits( + token_ids, output_logits, logits, - cu_batch_token_offset, - ori_cu_batch_token_offset, + share_inputs["cu_batch_token_offset"], + share_inputs["cu_seqlens_q_output"], share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], share_inputs["accept_num"], + share_inputs["accept_tokens"], ) - idx = paddle.arange(share_inputs["accept_tokens"].shape[1], dtype="int32") - mask = idx < share_inputs["accept_num"].unsqueeze(1) - token_ids = paddle.masked_select(share_inputs["accept_tokens"], mask) + # Adapt for sampling mask + if num_logprobs is None: + return None, None, output_logits # Compute logprobs with temperature scaling and top_p normalization if logprobs_mode == "raw_logprobs": - raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata) + raw_logprobs = compute_logprobs_fn(output_logits, sampling_metadata, real_bsz) elif logprobs_mode == "raw_logits": raw_logprobs = output_logits.clone() else: raw_logprobs = F.log_softmax(output_logits, axis=-1) logprobs_tensors = gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + # output_logits use to compute sampling_mask + return logprobs_tensors, share_inputs["cu_batch_token_offset"], output_logits + - return logprobs_tensors, cu_batch_token_offset +def logprobs_renormalize_with_logz(logprobs: paddle.Tensor, logz, logprobs_tensors: LogprobsTensors): + """ + Renormalize logprobs to match truncated sampling distribution. + Args: + logprobs: tensor [B, max_num_logprobs + 1] + logz: [B], log(sum(probs in candidate set K)) for each request. + Can be np.ndarray or paddle.Tensor (CPU pinned memory). + logprobs_tensors: LogprobsTensors + """ + if isinstance(logz, paddle.Tensor): + logz = logz.astype(logprobs.dtype) + else: + logz = paddle.to_tensor(logz, dtype=logprobs.dtype) + # Renormalize: log π_masked = log π_full - log Z_K + # Only normalize valid candidates; padding positions use -inf + valid_mask = paddle.isfinite(logprobs) + normalized_logprobs = paddle.where( + valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf")) + ) + # Update logprobs_tensors with normalized values + return LogprobsTensors( + logprob_token_ids=logprobs_tensors.logprob_token_ids, + logprobs=normalized_logprobs, + selected_token_ranks=logprobs_tensors.selected_token_ranks, + ) diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 0d7f6915ab4..b51ecb84010 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -42,6 +42,7 @@ class SamplingMetadata: step_idx: paddle.Tensor top_p: paddle.Tensor + top_p_list: Optional[list] = None # only GPU used bad_words_token_len: Optional[paddle.Tensor] = None top_k: Optional[paddle.Tensor] = None @@ -66,3 +67,5 @@ class SamplingMetadata: # Add for HPU post-processing seq_lens_encoder: Optional[paddle.Tensor] = None seq_lens_decoder: Optional[paddle.Tensor] = None + # Add for keep sampling mask + keep_sampling_mask: Optional[bool] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index 911c1697497..3b272ede7b3 100644 --- a/fastdeploy/model_executor/layers/sample/ops/__init__.py +++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py @@ -20,10 +20,14 @@ reasoning_phase_token_constraint, ) from .speculate_logprob_utils import ( - speculate_get_target_logits, + speculate_get_accept_tokens_and_logits, speculate_insert_first_token, ) -from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling +from .top_k_top_p_sampling import ( + dispatch_top_k_renorm_probs, + min_p_sampling, + top_k_top_p_sampling, +) __all__ = [ "apply_penalty_multi_scores", @@ -31,6 +35,7 @@ "reasoning_phase_token_constraint", "top_k_top_p_sampling", "min_p_sampling", - "speculate_get_target_logits", + "speculate_get_accept_tokens_and_logits", "speculate_insert_first_token", + "dispatch_top_k_renorm_probs", ] diff --git a/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py index 2caaf4892b8..df9bcaf6195 100644 --- a/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py +++ b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py @@ -19,29 +19,35 @@ from fastdeploy.platforms import current_platform -def speculate_get_target_logits( +def speculate_get_accept_tokens_and_logits( + token_ids: paddle.Tensor, target_logits: paddle.Tensor, logits: paddle.Tensor, cu_batch_token_offset: paddle.Tensor, - ori_cu_batch_token_offset: paddle.Tensor, + cu_seqlens_q_output: paddle.Tensor, seq_lens_this_time: paddle.Tensor, seq_lens_encoder: paddle.Tensor, accept_num: paddle.Tensor, + accept_tokens: paddle.Tensor, ): """ - speculate_get_target_logits + speculate_get_accept_tokens_and_logits """ if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import speculate_get_target_logits + from fastdeploy.model_executor.ops.gpu import ( + speculate_get_accept_tokens_and_logits, + ) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) else: raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index ff072e1a8ef..cc7c1c11277 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -34,6 +34,20 @@ def _reset_cuda_generator_for_determinism(): paddle.framework.core.default_cuda_generator(0).manual_seed(_DETERMINISTIC_RNG_SEED) +def dispatch_top_k_renorm_probs(probs, top_k): + try: + if current_platform.is_iluvatar(): + from fastdeploy.model_executor.ops.iluvatar import top_k_renorm_probs + else: + from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs + probs = top_k_renorm_probs(probs, top_k) + + except ImportError: + logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.") + + return probs + + def top_k_top_p_sampling( x: paddle.Tensor, top_p: paddle.Tensor, @@ -70,7 +84,6 @@ def top_k_top_p_sampling( """ top_p_class = envs.FD_SAMPLING_CLASS.lower() - topp_seed_device = None # In deterministic mode, reset CUDA generator offset before sampling. # paddle.tensor.top_p_sampling uses the global GPU generator offset even @@ -85,29 +98,17 @@ def top_k_top_p_sampling( _ = None else: if top_k_list and any(x > 0 for x in top_k_list): - try: - if current_platform.is_iluvatar(): - from fastdeploy.model_executor.ops.iluvatar import ( - top_k_renorm_probs, - ) - else: - from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs - x = top_k_renorm_probs(x, top_k) - except ImportError: - logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.") + x = dispatch_top_k_renorm_probs(x, top_k) if top_p_class == "air": _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode) elif top_p_class == "base_non_truncated": - if topp_seed is not None: - topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype) - topp_seed_device.copy_(topp_seed, False) _, ids = paddle.tensor.top_p_sampling( x, top_p, threshold=threshold, - topp_seed=topp_seed_device, + topp_seed=topp_seed, seed=seed, k=k, mode="non-truncated", @@ -122,14 +123,11 @@ def top_k_top_p_sampling( _, ids = native_top_p_sampling(x, top_p) else: - if topp_seed is not None: - topp_seed_device = paddle.empty(shape=topp_seed.shape, dtype=topp_seed.dtype) - topp_seed_device.copy_(topp_seed, False) _, ids = paddle.tensor.top_p_sampling( x, top_p, threshold=threshold, - topp_seed=topp_seed_device, + topp_seed=topp_seed, seed=seed, k=k, mode="truncated", diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py new file mode 100644 index 00000000000..cc2fe4faafa --- /dev/null +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py @@ -0,0 +1,992 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +""" +Combined Top-K and Top-P Triton kernels. + +Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs +using Pivot-based Truncation and Selection" By Park et al. +(https://arxiv.org/abs/2602.01518) + +""" + +import warnings + +import paddle +from paddle.utils.deprecated import VisibleDeprecationWarning + +from fastdeploy.model_executor.ops.triton_ops.triton_utils import ( + enable_compat_on_triton_kernel, +) + +# Suppress the VisibleDeprecationWarning from use_triton_in_paddle that fires +# on every Triton kernel launch (paddle.device.cuda.current_stream / +# synchronize). In serving hot-paths this produces thousands of log lines per +# second and the I/O overhead alone can cause client-visible timeouts. +warnings.filterwarnings("ignore", category=VisibleDeprecationWarning) + +import triton # noqa: E402 +import triton.language as tl # noqa: E402 + +_TRITON_TABLE_CACHE: dict[tuple[paddle.device], tuple[paddle.Tensor, paddle.Tensor]] = {} +_TRITON_BUFFER_CACHE: dict[tuple[paddle.device, paddle.dtype, int], paddle.Tensor] = {} + +# fmt: off +_NORMAL_CDF_TO_SIGMA_TABLE = [ + 3.656, 3.650, 3.650, 3.650, 3.626, 3.626, 3.626, 3.514, 3.514, 3.503, + 3.503, 3.434, 3.434, 3.428, 3.428, 3.387, 3.380, 3.380, 3.376, 3.373, + 3.373, 3.356, 3.354, 3.354, 3.291, 3.249, 3.234, 3.214, 3.198, 3.198, + 3.185, 3.177, 3.177, 3.165, 3.164, 3.161, 3.138, 3.120, 3.115, 3.113, + 3.093, 3.066, 3.054, 3.043, 3.037, 3.023, 2.993, 2.991, 2.976, 2.970, + 2.952, 2.946, 2.932, 2.908, 2.902, 2.895, 2.886, 2.874, 2.861, 2.844, + 2.836, 2.810, 2.801, 2.790, 2.784, 2.779, 2.767, 2.757, 2.745, 2.733, + 2.723, 2.716, 2.693, 2.678, 2.671, 2.656, 2.649, 2.629, 2.611, 2.595, + 2.592, 2.585, 2.574, 2.550, 2.543, 2.534, 2.521, 2.518, 2.497, 2.485, + 2.468, 2.450, 2.441, 2.430, 2.412, 2.402, 2.389, 2.383, 2.377, 2.364, + 2.349, 2.338, 2.332, 2.319, 2.310, 2.301, 2.282, 2.274, 2.266, 2.250, + 2.242, 2.236, 2.226, 2.215, 2.207, 2.196, 2.179, 2.171, 2.162, 2.147, + 2.135, 2.121, 2.109, 2.095, 2.085, 2.073, 2.063, 2.045, 2.030, 2.016, + 2.003, 1.992, 1.983, 1.972, 1.960, 1.949, 1.940, 1.928, 1.912, 1.897, + 1.881, 1.869, 1.854, 1.838, 1.824, 1.807, 1.792, 1.779, 1.764, 1.751, + 1.739, 1.726, 1.711, 1.697, 1.685, 1.668, 1.652, 1.636, 1.622, 1.603, + 1.585, 1.568, 1.551, 1.534, 1.513, 1.499, 1.480, 1.464, 1.441, 1.422, + 1.394, 1.373, 1.347, 1.320, 1.296, 1.270, 1.246, 1.219, 1.190, 1.163, + 1.135, 1.104, 1.073, 1.041, 1.006, 0.969, 0.931, 0.894, 0.851, 0.806, + 0.757, 0.702, 0.643, 0.574, 0.498, 0.405, 0.288, 0.134, -0.110, -3.813 +] + +_PERCENTILE_TO_STD_TABLE = [ + 2.576, 2.319, 2.178, 2.064, 1.968, 1.892, 1.819, 1.757, 1.708, 1.659, + 1.616, 1.568, 1.526, 1.492, 1.456, 1.420, 1.382, 1.342, 1.309, 1.280, + 1.249, 1.221, 1.193, 1.169, 1.145, 1.121, 1.095, 1.073, 1.050, 1.030, + 1.008, 0.987, 0.966, 0.945, 0.926, 0.910, 0.891, 0.871, 0.854, 0.837, + 0.819, 0.803, 0.784, 0.767, 0.753, 0.734, 0.719, 0.702, 0.690, 0.675, + 0.658, 0.640, 0.625, 0.609, 0.595, 0.578, 0.564, 0.550, 0.537, 0.521, + 0.509, 0.495, 0.481, 0.466, 0.453, 0.439, 0.424, 0.410, 0.397, 0.383, + 0.370, 0.356, 0.343, 0.330, 0.316, 0.302, 0.289, 0.274, 0.261, 0.247, + 0.235, 0.223, 0.209, 0.196, 0.184, 0.172, 0.159, 0.149, 0.137, 0.124, + 0.112, 0.100, 0.086, 0.074, 0.062, 0.050, 0.035, 0.023, 0.009, -0.003, + -0.015, -0.027, -0.039, -0.052, -0.063, -0.074, -0.085, -0.097, -0.109, -0.122, + -0.134, -0.147, -0.158, -0.171, -0.184, -0.196, -0.210, -0.223, -0.235, -0.248, + -0.261, -0.275, -0.289, -0.302, -0.317, -0.328, -0.341, -0.353, -0.368, -0.382, + -0.396, -0.410, -0.426, -0.439, -0.452, -0.465, -0.480, -0.493, -0.507, -0.521, + -0.537, -0.551, -0.568, -0.582, -0.597, -0.614, -0.628, -0.643, -0.658, -0.673, + -0.691, -0.706, -0.721, -0.738, -0.754, -0.769, -0.789, -0.808, -0.824, -0.838, + -0.857, -0.877, -0.893, -0.912, -0.929, -0.947, -0.965, -0.983, -1.003, -1.027, + -1.050, -1.070, -1.092, -1.117, -1.139, -1.162, -1.189, -1.216, -1.241, -1.272, + -1.300, -1.330, -1.367, -1.404, -1.441, -1.485, -1.523, -1.564, -1.607, -1.658, + -1.710, -1.778, -1.832, -1.901, -1.978, -2.068, -2.174, -2.325, -2.577, -3.813 +] +# fmt: on + + +@triton.jit +def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel): + """Update running (min, count) of values above a pivot across tiles. + + Tracks the smallest value strictly above a pivot and how many times + it occurs. Called once per tile per pivot; the running state is + carried across tiles via `min_larger` / `num_min_larger`. + + Merge rule: + - tile min < running min → replace both + - tile min == running min → accumulate count + - tile min > running min → keep running values + """ + tile_min = tl.min(tl.where(above_mask, data, sentinel)) + tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9) + tile_cnt = tl.sum(tile_eq) + is_new = tile_min < min_larger + is_same = tl.abs(tile_min - min_larger) < 1e-9 + num_min_larger = tl.where(is_new, tile_cnt, num_min_larger + tile_cnt * is_same) + min_larger = tl.minimum(min_larger, tile_min) + return min_larger, num_min_larger + + +@enable_compat_on_triton_kernel +@triton.jit +def _topk_topp_kernel( + LOGITS, + BUFFER, + MASK_OUT, + PERCENTILE_TO_STD_TABLE, + NORMAL_CDF_TO_SIGMA_TABLE, + K, + P, + BATCH_SIZE, + VOCAB_SIZE: tl.constexpr, + MASK_VALUE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_TRUNC: tl.constexpr, + TOPK_ENABLED: tl.constexpr, + TOPP_ENABLED: tl.constexpr, + WRITE_MASK: tl.constexpr, +): + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, BATCH_SIZE, num_programs): + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + + final_pivot = -float("inf") + duplicate_logit = float("inf") + num_duplicate_logit = tl.zeros((), dtype=tl.uint32) + num_keep = tl.zeros((), dtype=tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + max_logit = -float("inf") + min_logit = float("inf") + + if TOPK_ENABLED: + k = tl.load(K + row_id) + if k < VOCAB_SIZE: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.15 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + num_finite_total = tl.zeros((), dtype=tl.uint32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk_mask = logits_blk > -float("inf") + finite_blk = tl.where(finite_blk_mask, logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + num_finite_total += tl.sum(finite_blk_mask & mask_n) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + # Second passes: Ternary search for pivots + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # Single fused pass: compute k_pivots_num, + # min_larger, and num_min_larger together to avoid + # a second data scan. See _update_min_larger_stats + # for the tile-level merge logic. + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")) + + above_0 = logits_blk2 > k_pivot_0 + above_1 = logits_blk2 > k_pivot_1 + k_pivots_num_0 += tl.sum(above_0) + k_pivots_num_1 += tl.sum(above_1) + + min_larger_0, num_min_larger_0 = _update_min_larger_stats( + logits_blk2, + above_0, + min_larger_0, + num_min_larger_0, + float("inf"), + ) + min_larger_1, num_min_larger_1 = _update_min_larger_stats( + logits_blk2, + above_1, + min_larger_1, + num_min_larger_1, + float("inf"), + ) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # Single fused pass over full vocab (same approach + # as the buffer path above). + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + above_0 = logits_blk2 > k_pivot_0 + above_1 = logits_blk2 > k_pivot_1 + k_pivots_num_0 += tl.sum(above_0) + k_pivots_num_1 += tl.sum(above_1) + + min_larger_0, num_min_larger_0 = _update_min_larger_stats( + logits_blk2, + above_0, + min_larger_0, + num_min_larger_0, + float("inf"), + ) + min_larger_1, num_min_larger_1 = _update_min_larger_stats( + logits_blk2, + above_1, + min_larger_1, + num_min_larger_1, + float("inf"), + ) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k only path. If there are fewer finite values + # than k (e.g. grammar mask), keep everything. + final_pivot = k_pivot if num_finite_total > k else -float("inf") + + if TOPP_ENABLED and num_finite_total > k: + #### TOP-P SAMPLING AFTER TOP-K #### + p = tl.load(P + row_id) + if p < 1.0: + min_logit = k_pivot + sum_exp_logits = 0.0 + num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Third pass: Calculate exp logits and sum, gather outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling for Top-k + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, + # retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load( + LOGITS_ROW + offs_n, + mask=mask_n, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate logit handling for Top-k + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, + tl.int32, + ) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and (p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and (p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k + Top-p path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + if TOPP_ENABLED and final_pivot == -float("inf"): + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk = tl.where(logits_blk > -float("inf"), logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + + outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) + + # Second pass: Calculate softmax and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + max_range = tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # Re-populate the buffer with full softmax probabilities + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + + # Sixth pass: Apply mask and store final output. + # If the pivot >= max logit (or is NaN), no token would + # survive the strict `>` keep_mask. Skip masking. + # Using `not <` instead of `>=` so that NaN is also caught. + if not (final_pivot < max_logit): + final_pivot = -float("inf") + elif final_pivot != -float("inf"): + if WRITE_MASK: + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + keep_mask = (logits_blk > final_pivot) & mask_n + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) + + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + if WRITE_MASK: + tl.store(MASK_ROW + offs_n, keep_mask, mask=mask_n) + + # When no masking was applied (final_pivot == -inf), all tokens are kept. + if WRITE_MASK and final_pivot == -float("inf"): + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + tl.store(MASK_ROW + offs_n, mask_n, mask=mask_n) + + +def apply_top_k_top_p_triton( + logits: paddle.Tensor, + k: paddle.Tensor | None, + p: paddle.Tensor | None, + mask_value: float = float("-inf"), + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k and top-p masking using Triton. + + Top-k is applied first (by logit value), then top-p is applied + to the remaining k values (by probability). + + Args: + logits: [batch_size, vocab_size] float32 tensor, modified in-place + k: [batch_size] int32 tensor of top-k values per row, or None to disable top-k + p: [batch_size] float32 tensor of top-p values per row (0 to 1), + or None to disable top-p + mask_value: Value for masked positions (default: -inf) + return_mask: If True, also return a bool mask [batch_size, vocab_size] + where True = retained token. The mask is computed inside the kernel + with zero extra memory bandwidth cost. + + Returns: + logits if return_mask is False, else (logits, mask). + """ + assert logits.ndim == 2 + assert logits.dtype == paddle.float32 + + batch_size, vocab_size = logits.shape + + topk_enabled = k is not None + topp_enabled = p is not None + + if batch_size == 0 or not (topk_enabled or topp_enabled): + if return_mask: + mask = paddle.ones(logits.shape, dtype=paddle.bool) + return logits, mask + return logits + + if k is not None: + assert k.ndim == 1 and k.shape[0] == batch_size + k_ptr = k.to(paddle.int32) + else: + k_ptr = logits # Dummy pointer (won't be read) + + if p is not None: + assert p.ndim == 1 and p.shape[0] == batch_size + p_ptr = p.to(paddle.float32) + else: + p_ptr = logits # Dummy pointer (won't be read) + + num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count + NUM_PROGRAMS = min(num_sm, batch_size) + + # Cache per-Triton Program buffer on each device. + buf_key = (logits.device, logits.dtype, vocab_size) + buffer = _TRITON_BUFFER_CACHE.get(buf_key) + if buffer is None or buffer.shape[0] < NUM_PROGRAMS: + size = min(triton.next_power_of_2(NUM_PROGRAMS), num_sm) + buffer = paddle.empty((size, vocab_size), dtype=logits.dtype) + _TRITON_BUFFER_CACHE[buf_key] = buffer + if buffer.shape[0] > NUM_PROGRAMS: + buffer = buffer[:NUM_PROGRAMS] + + # Allocate mask output if requested. + write_mask = return_mask + if write_mask: + mask_out = paddle.empty(logits.shape, dtype=paddle.int8) + else: + mask_out = logits # Dummy pointer (won't be written) + + # Cache lookup table entries on each device. + tables = _TRITON_TABLE_CACHE.get(logits.device) + if tables is None: + normal_cdf_to_sigma_table = paddle.to_tensor( + _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place + ) + percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place) + _TRITON_TABLE_CACHE[logits.device] = ( + normal_cdf_to_sigma_table, + percentile_to_std_table, + ) + else: + normal_cdf_to_sigma_table, percentile_to_std_table = tables + + _topk_topp_kernel[(NUM_PROGRAMS,)]( + logits, + buffer, + mask_out, + percentile_to_std_table, + normal_cdf_to_sigma_table, + k_ptr, + p_ptr, + BATCH_SIZE=batch_size, + MASK_VALUE=mask_value, + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=8192, + BLOCK_SIZE_TRUNC=4096, + TOPK_ENABLED=topk_enabled, + TOPP_ENABLED=topp_enabled, + WRITE_MASK=write_mask, + ) + + if return_mask: + return logits, mask_out.astype(paddle.bool) + return logits + + +@enable_compat_on_triton_kernel +@triton.jit +def _seeded_gumbel_kernel( + OUT_ptr, + SEEDS_ptr, + stride_out_batch, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Generate -log(u) with per-row Philox seeds, fully on GPU.""" + pid = tl.program_id(0) + seed = tl.load(SEEDS_ptr + pid) + for start in tl.range(0, VOCAB_SIZE, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < VOCAB_SIZE + u = tl.rand(seed, offsets) + u = tl.maximum(u, 1e-10) + q = -tl.log(u) + tl.store(OUT_ptr + pid * stride_out_batch + offsets, q, mask=mask) + + +def seeded_gumbel_noise(probs: paddle.Tensor, seeds: paddle.Tensor) -> paddle.Tensor: + """ + Generate Gumbel noise q = -log(u) with per-row Philox seeds on GPU. + + Args: + probs: [batch_size, vocab_size] — used only for shape/dtype. + seeds: [batch_size] int64 per-request seeds (GPU). + + Returns: + q: [batch_size, vocab_size] float tensor of Gumbel noise. + """ + batch_size, vocab_size = probs.shape + q = paddle.empty_like(probs) + BLOCK_SIZE = min(triton.next_power_of_2(vocab_size), 4096) + _seeded_gumbel_kernel[(batch_size,)]( + q, + seeds, + q.strides[0], + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return q + + +def reset_buffer_cache(): + _TRITON_BUFFER_CACHE.clear() + _TRITON_TABLE_CACHE.clear() + paddle.accelerator.empty_cache() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 08a33c11096..cac8e7249a8 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -19,6 +19,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, List, Optional +import numpy as np import paddle import paddle.nn.functional as F from paddle import nn @@ -26,7 +27,7 @@ from fastdeploy import envs from fastdeploy.config import FDConfig -from fastdeploy.envs import FD_FILL_BITMASK_BATCH +from fastdeploy.envs import FD_FILL_BITMASK_BATCH, FD_SAMPLING_CLASS from fastdeploy.logger.deterministic_logger import _record_logits_diagnostic from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( @@ -40,11 +41,16 @@ from fastdeploy.model_executor.layers.sample.ops import ( apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, + dispatch_top_k_renorm_probs, min_p_sampling, reasoning_phase_token_constraint, speculate_insert_first_token, top_k_top_p_sampling, ) +from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + apply_top_k_top_p_triton, + seeded_gumbel_noise, +) from fastdeploy.platforms import current_platform from fastdeploy.reasoning import ReasoningParser from fastdeploy.spec_decode import SpecMethod, VerifyStrategy @@ -53,10 +59,78 @@ if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( build_sampling_params, + build_sampling_params_logprob, naive_update_model_status, ) +def _apply_triton_top_k_top_p( + logits: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k/top-p masking on logits using the Triton kernel. + Masked positions are set to -inf in-place. Call this BEFORE softmax. + + Args: + return_mask: If True, return (logits, mask) where mask is a bool + tensor [B, V] computed inside the Triton kernel (zero extra cost). + + Returns: + logits if return_mask is False, else (logits, mask). + """ + if top_p is None and top_k is None: + return logits + batch_size = logits.shape[0] + + top_p = top_p[:batch_size].squeeze(axis=-1) + + has_top_k = top_k_list and any(x > 0 for x in top_k_list) + if has_top_k: + top_k = top_k[:batch_size].squeeze(axis=-1) + else: + top_k = None + + return apply_top_k_top_p_triton(logits.astype("float32"), k=top_k, p=top_p, return_mask=return_mask) + + +def _random_sample( + probs: paddle.Tensor, + topp_seed: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + """ + Sample from probabilities using the Gumbel-max trick. + + Equivalent to multinomial sampling but avoids CPU-GPU synchronization. + When ``topp_seed`` is provided and Triton is available, a Triton kernel + generates per-row deterministic Gumbel noise using Philox PRNG entirely + on GPU, eliminating the Python for-loop and CPU-GPU sync overhead. + + Args: + probs: [batch_size, vocab_size] float32 probabilities. + topp_seed: [batch_size, 1] int64 per-request seeds, or None. + + Returns: + Token ids of shape [batch_size, 1]. + + Reference: vllm/v1/sample/ops/topk_topp_sampler.py::random_sample + """ + # Sample from Exp(1): q = -log(u), u ~ Uniform(0, 1) + if topp_seed is not None: + seeds = topp_seed[: probs.shape[0]].reshape([-1]) + if not seeds.place.is_gpu_place(): + seeds = seeds.cuda() + q = seeded_gumbel_noise(probs, seeds) + else: + u = paddle.uniform(probs.shape, dtype=probs.dtype, min=0.0, max=1.0) + q = -paddle.log(u.clip(min=1e-10)) + # Gumbel-max: argmax(probs / q) is equivalent to multinomial(probs) + return (probs / q).argmax(axis=-1).reshape([-1, 1]) + + def top_p_normalize_probs_paddle( probs: paddle.Tensor, top_ps: paddle.Tensor, @@ -105,6 +179,202 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le return top_p_padding, top_k_padding, topp_seed +def _compute_sampling_mask( + probs: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, +) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, int]: + """ + Compute a combined top-k + top-p (nucleus) sampling mask — GPU only, + no D2H transfer or CPU sync. + + Processing order: + 1. Sort probs descending once (shared by top-k and top-p stages). + 2. top-k mask — zero out positions beyond top_k[i] in sorted order. + 3. top-k renorm — renormalise in-place after truncation. + 4. top-p mask — cumsum on the already-sorted renormed probs; no + second argsort needed. + 5. intersect — AND of the two masks, applied on GPU before D2H. + + Either filter can be disabled: + - top-k is skipped when top_k_list is None or all values <= 0. + - top-p[i] >= 1.0 → keep all tokens for that request. + + Args: + probs: [num_reqs, vocab_size] softmax probabilities (GPU). + top_p: [num_reqs, 1] top-p threshold per request (GPU). + top_k: [num_reqs, 1] top-k per request (GPU, int); 0 = disabled. + top_k_list: Python list of top-k values; used to decide whether any + top-k filtering is needed at all. + + Returns: + Tuple of (indices_window, mask_window, logz_per_batch, real_bsz): + - indices_window: [B, max_k] GPU int64 tensor of sorted vocab indices. + - mask_window: [B, max_k] GPU bool tensor, True = retained. + - logz_per_batch: [B] GPU float32 tensor, log(Z_K) per request. + - real_bsz: int, the batch size. + """ + real_bsz = probs.shape[0] + vocab_size = probs.shape[1] + top_p = top_p[:real_bsz] # [B, 1] + + has_top_k = top_k is not None and top_k_list and any(x > 0 for x in top_k_list) + + # ------------------------------------------------------------------ + # Stage 1: single sort — descending by probability. + # sorted_indices / sorted_probs are reused by both top-k and top-p. + # ------------------------------------------------------------------ + sorted_indices = paddle.argsort(probs, axis=-1, descending=True) # [B, V] + sorted_probs = paddle.take_along_axis(probs, sorted_indices, axis=-1) # [B, V] + + # ------------------------------------------------------------------ + # Stage 2: top-k mask (GPU, no D2H) + # ------------------------------------------------------------------ + if has_top_k: + top_k = top_k[:real_bsz] # [B, 1] + # top_k == 0 means "disabled" → keep all columns for that row. + effective_k = paddle.where(top_k > 0, top_k, paddle.full_like(top_k, vocab_size)) + + # Relax: also keep positions whose prob ties with the k-th element. + # boundary index (0-based) = effective_k - 1, clamped to [0, V-1]. + k_idx = (effective_k - 1).clip(min=0).squeeze(-1).astype("int64") # [B] k-th index + batch_idx = paddle.arange(k_idx.shape[0], dtype="int64") # [B] bs index + boundary_prob = sorted_probs[batch_idx, k_idx].unsqueeze(-1) # [B, 1] min_probs in topk candidates + topk_mask = sorted_probs >= boundary_prob # [B, V] True = retained by top-k + + # Zero out tail, then renorm row-wise. + masked_sorted_probs = paddle.where(topk_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + row_sums = masked_sorted_probs.sum(axis=-1, keepdim=True).clip(min=1e-9) + renorm_sorted_probs = masked_sorted_probs / row_sums # [B, V] + else: + topk_mask = None + renorm_sorted_probs = sorted_probs + + # ------------------------------------------------------------------ + # Stage 3: top-p mask on already-sorted renormed probs (no re-sort). + # ------------------------------------------------------------------ + cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V] + topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V] + # When top_p[i] >= 1.0, keep the entire row. + topp_mask = paddle.where( + (top_p >= 1.0).expand_as(topp_mask), + paddle.ones_like(topp_mask), + topp_mask, + ) + + # Extend mask to cover sort tie-breaking: include all tokens whose + # probability >= the boundary token's probability (last retained + # in sorted order). In descending-sorted probs this just extends + # the contiguous True block by the run of equal-prob tokens. + k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1] + # boundary_idx = last True position (k-1), clamp for safety + boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1] + boundary_prob = paddle.take_along_axis( + renorm_sorted_probs, + boundary_idx, + axis=-1, + ) # [B, 1] + topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob) + + # ------------------------------------------------------------------ + # Stage 4: intersect on GPU, then minimal D2H. + # ------------------------------------------------------------------ + final_mask = topk_mask & topp_mask if has_top_k else topp_mask # [B, V] + + k_per_row = final_mask.astype("int32").sum(axis=-1) # [B] + max_k = k_per_row.max().reshape([-1]) # [1], stays on GPU + + # ------------------------------------------------------------------ + # Stage 5: compute logZ_K for renormalization + # Z_K = sum(probs[i] * final_mask[i]) for each request i + # logZ_K = log(Z_K), with small constant to avoid log(0) + # ------------------------------------------------------------------ + candidate_probs = paddle.where(final_mask, sorted_probs, paddle.zeros_like(sorted_probs)) + z_k = candidate_probs.sum(axis=-1) # [B] + logz_per_batch = paddle.log(z_k + 1e-10) # [B], GPU + + # Slice only the leading max_k columns on GPU — typically max_k << vocab_size. + # All outputs stay on GPU; D2H is deferred to save_output via async copy_. + indices_window = sorted_indices.slice([1], [0], max_k) # [B, max_k] + mask_window = final_mask.slice([1], [0], max_k) # [B, max_k] + + return indices_window, mask_window, logz_per_batch, real_bsz + + +def _extract_sparse_indices( + indices_window_cpu: np.ndarray, + mask_window_cpu: np.ndarray, + real_bsz: int, +) -> List[np.ndarray]: + """ + Extract per-request sparse retained-token indices from CPU numpy arrays. + + This is the CPU-side counterpart of _compute_sampling_mask. + + Args: + indices_window_cpu: [B, max_k] int64 numpy array of sorted vocab indices. + mask_window_cpu: [B, max_k] bool numpy array, True = retained. + real_bsz: batch size (number of rows to process). + + Returns: + List of length real_bsz; element i is a 1-D int64 numpy array of + retained vocab indices for request i. + """ + return [indices_window_cpu[i, mask_window_cpu[i]] for i in range(real_bsz)] + + +def _sample_from_probs(probs, sampling_metadata, top_p=None, top_k=None, topp_seed=None): + """Sample next tokens from probability distributions with optional top-k and top-p filtering. + + When ``top_p_list`` is all 1.0 (no top-p filtering needed), uses + :func:`_random_sample` with an optional top-k renormalization pass via + :func:`dispatch_top_k_renorm_probs`. Otherwise dispatches through + :func:`top_k_top_p_sampling` to apply joint top-k/top-p constraints. + + Args: + probs: [token_num, vocab_size] float32 probability tensor (normalized logits). + sampling_metadata: Metadata carrying top_p, top_k, seed, top_k_list, + and top_p_list for the current batch of requests. + top_p: Override for per-row top-p values, shape [token_num, 1] or None. + top_k: Override for per-row top-k values, shape [token_num, 1] or None. + topp_seed: Override for per-row random seeds, shape [token_num, 1] or None. + + Returns: + Sampled token ids of shape [token_num, 1]. + """ + token_num = probs.shape[0] + if top_p is None: + top_p = sampling_metadata.top_p + if top_k is None: + top_k = sampling_metadata.top_k + if topp_seed is None: + topp_seed = sampling_metadata.seed + top_k_list = sampling_metadata.top_k_list + top_p_list = sampling_metadata.top_p_list + need_top_k_sampling = False + need_top_p_sampling = True + if top_k_list is not None: + top_k_list = top_k_list[:token_num] + need_top_k_sampling = any(k > 0 for k in top_k_list) + if top_p_list is not None: + top_p_list = top_p_list[:token_num] + need_top_p_sampling = any(p != 1.0 for p in top_p_list) + if not need_top_p_sampling and current_platform.is_cuda() and envs.FD_ENABLE_TOP_P_ONE_OPT: + if need_top_k_sampling: + probs = dispatch_top_k_renorm_probs(probs, top_k) + next_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + _, next_tokens = top_k_top_p_sampling( + probs, + top_p, + top_k, + top_k_list, + topp_seed=topp_seed, + ) + return next_tokens + + class GuidedDecoding: """ processor for guided decoding. @@ -547,6 +817,16 @@ def forward_cuda( elif self.logprobs_mode == "processed_logits": raw_logprobs = logits.clone() + # Triton path: mask logits in-place BEFORE softmax (no probs→log round-trip). + if FD_SAMPLING_CLASS.lower() == "triton": + logits = _apply_triton_top_k_top_p( + logits, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + return_mask=False, + ) + probs = F.softmax(logits) # Record post-penalty logits and probs MD5 for determinism diagnosis @@ -554,13 +834,34 @@ def forward_cuda( _record_logits_diagnostic(logits, tag="post_penalty_logits", probs=probs) probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) - _, next_tokens = top_k_top_p_sampling( - probs, - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.top_k_list, - topp_seed=sampling_metadata.seed, - ) + + # Compute sampling mask BEFORE top_k_top_p_sampling modifies probs. + # All GPU ops; D2H is done via async copy_ with event sync in save_output. + sampling_mask = None + logz_per_batch = None + if sampling_metadata.keep_sampling_mask: + indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( + probs, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + # Allocate CPU pinned tensors and async copy + indices_window_cpu = paddle.empty_like( + indices_window_gpu, dtype=indices_window_gpu.dtype, device="cpu" + ).pin_memory() + mask_window_cpu = paddle.empty_like( + mask_window_gpu, dtype=mask_window_gpu.dtype, device="cpu" + ).pin_memory() + indices_window_cpu.copy_(indices_window_gpu, False) + mask_window_cpu.copy_(mask_window_gpu, False) + # Store deferred GPU→CPU data; sparse extraction happens in save_output + sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) + + if FD_SAMPLING_CLASS.lower() == "triton": + next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed) + else: + next_tokens = _sample_from_probs(probs, sampling_metadata) logprobs_tensors = ( None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) @@ -577,6 +878,8 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, logits=logits, + sampling_mask=sampling_mask, + logz_per_batch=logz_per_batch, ) return sampler_output @@ -654,6 +957,7 @@ def __init__(self, fd_config: FDConfig): self.spec_method = spec_config.method self.verify_strategy = spec_config.verify_strategy self.prefill_one_step_stop = fd_config.parallel_config.prefill_one_step_stop + self.num_speculative_tokens = spec_config.num_speculative_tokens # Accept policy from config (can be overridden by function parameters) self.config_accept_all = spec_config.accept_policy == "accept_all" @@ -679,54 +983,54 @@ def compute_logprobs( self, logits: paddle.Tensor, sampling_metadata: SamplingMetadata, + real_bsz: int = 0, ) -> paddle.Tensor: """compute logprobs""" share_inputs = sampling_metadata.share_inputs last_logits = logits - real_bsz = share_inputs["seq_lens_this_time"].shape[0] - batch_token_num = share_inputs["accept_num"][:real_bsz] + + # NOTE(huicongyao): temporarily used to provide a max_sized input, remove in the future + num_tokens = real_bsz * (self.num_speculative_tokens + 1) + padded_logits = paddle.zeros(shape=[num_tokens, last_logits.shape[1]], dtype=last_logits.dtype) + padded_logits[: logits.shape[0]] = last_logits + max_occupied_slots = share_inputs["seq_lens_this_time"].shape[0] + + batch_token_num = share_inputs["accept_num"][:max_occupied_slots] temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs if temp_scaled_logprobs is not None: - real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] - temperature = sampling_metadata.temperature[:real_bsz] - real_bsz_temp_scaled = ( - real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") - ) - temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) + real_bsz_temp_scaled = temp_scaled_logprobs[:max_occupied_slots] + temperature = sampling_metadata.temperature[:max_occupied_slots] + real_bsz_temp_scaled = build_sampling_params_logprob(real_bsz_temp_scaled, batch_token_num, num_tokens) + temperature = build_sampling_params_logprob(temperature, batch_token_num, num_tokens) temp_temperature = paddle.where( real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) ).unsqueeze(1) - last_logits = last_logits / temp_temperature + padded_logits = padded_logits / temp_temperature - last_logprobs = F.log_softmax(last_logits, axis=-1) + last_logprobs = F.log_softmax(padded_logits, axis=-1) top_p_logprob = None top_p_token_mask = None - if ( top_p_normalized_logprobs is not None and share_inputs is not None and sampling_metadata.top_p_normalized_logprobs_flag ): - real_token_top_p = ( - sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) - ) - top_p_normalized_logprobs = ( - top_p_normalized_logprobs[:real_bsz] - .astype("int32") - .squeeze(1) - .repeat_interleave(batch_token_num) - .astype("bool") - .unsqueeze(1) - ) + real_token_top_p = build_sampling_params_logprob( + sampling_metadata.top_p[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens + ).unsqueeze(1) + top_p_normalized_logprobs = build_sampling_params_logprob( + top_p_normalized_logprobs[:max_occupied_slots].squeeze(1), batch_token_num, num_tokens + ).unsqueeze(1) top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) - if top_p_token_mask.any(): - probs = F.softmax(last_logits, axis=-1) - probs = top_p_normalize_probs_paddle(probs, real_token_top_p) - top_p_logprob = paddle.log(probs) + + probs = F.softmax(padded_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) if top_p_logprob is not None: last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs def gather_logprobs( @@ -769,7 +1073,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices, top_logprobs, token_ranks) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def _verify_and_sample( self, @@ -782,6 +1097,7 @@ def _verify_and_sample( increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, + topp_seed: Optional[paddle.Tensor] = None, ) -> SamplerOutput: """ Verify draft tokens against target model output and produce final samples. @@ -813,19 +1129,22 @@ def _verify_and_sample( target_tokens, candidate_ids, candidate_scores, candidate_lens = None, None, None, None if self.verify_strategy == VerifyStrategy.TARGET_MATCH: - # Only TARGET_MATCH needs stochastic sampling - top_p, top_k, topp_seed = build_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - share_inputs["seq_lens_this_time"], - share_inputs["cu_seqlens_q_output"], - token_num_output_cpu, - increment_value, - ) - _, target_tokens = top_k_top_p_sampling( - probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed - ) + if FD_SAMPLING_CLASS.lower() == "triton": + target_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + # Only TARGET_MATCH needs stochastic sampling + top_p, top_k, topp_seed = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, + ) + target_tokens = _sample_from_probs( + probs, sampling_metadata, top_p=top_p, top_k=top_k, topp_seed=topp_seed + ) elif self.verify_strategy == VerifyStrategy.GREEDY: # GREEDY: deterministic argmax in target_tokens, no candidates needed target_tokens = paddle.argmax(probs, axis=-1) @@ -890,6 +1209,7 @@ def _normal_sample( probs: paddle.Tensor, sampling_metadata: SamplingMetadata, share_inputs: List[paddle.Tensor], + topp_seed: Optional[paddle.Tensor] = None, ) -> SamplerOutput: """ Normal sampling without draft token verification. @@ -911,13 +1231,16 @@ def _normal_sample( probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) # Sample tokens - _, next_tokens = top_k_top_p_sampling( - probs, - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.top_k_list, - topp_seed=sampling_metadata.seed, - ) + if FD_SAMPLING_CLASS.lower() == "triton": + next_tokens = _random_sample(probs, topp_seed=topp_seed) + else: + next_tokens = _sample_from_probs( + probs, + sampling_metadata, + top_p=sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + topp_seed=sampling_metadata.seed, + ) # Scatter sampled tokens into accept_tokens using cu_seqlens_q_output to # correctly handle mixed prefill+decode batches where token index != batch index. @@ -946,6 +1269,7 @@ def forward_cuda( increment_value: int, accept_all_drafts: bool = False, reject_all_drafts: bool = False, + real_bsz: int = 0, ) -> SamplerOutput: """ Forward pass for speculative sampling. @@ -1010,12 +1334,32 @@ def forward_cuda( self.line_break_id, ) + logits_ori = None + topp_seed = None + if FD_SAMPLING_CLASS.lower() == "triton": + logits_ori = logits.clone() + top_p, top_k, topp_seed = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, + ) + logits = _apply_triton_top_k_top_p( + logits, + top_p, + top_k=top_k, + top_k_list=sampling_metadata.top_k_list, + ) + probs = F.softmax(logits) # Route based on spec_method is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE if is_naive: - sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs) + sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs, topp_seed=topp_seed) else: sampler_output = self._verify_and_sample( logits, @@ -1027,21 +1371,59 @@ def forward_cuda( increment_value, accept_all_drafts, reject_all_drafts, + topp_seed=topp_seed, ) + keep_sampling_mask = sampling_metadata.keep_sampling_mask # Build logprobs via unified path (outside of sampling logic) - if sampling_metadata.max_num_logprobs is not None: - logprobs_tensors, cu_batch_token_offset = build_output_logprobs( - logits, + if sampling_metadata.max_num_logprobs is not None or keep_sampling_mask: + logprobs_tensors, cu_batch_token_offset, target_logits = build_output_logprobs( + logits if logits_ori is None else logits_ori, sampling_metadata, share_inputs, is_naive=is_naive, logprobs_mode=self.logprobs_mode, compute_logprobs_fn=self.compute_logprobs, + real_bsz=real_bsz, ) sampler_output.logprobs_tensors = logprobs_tensors if cu_batch_token_offset is not None: - sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() + cu_batch_token_offset_cpu = paddle.empty_like(cu_batch_token_offset, device="cpu").pin_memory() + cu_batch_token_offset_cpu.copy_(cu_batch_token_offset, False) + sampler_output.cu_batch_token_offset = cu_batch_token_offset_cpu + if keep_sampling_mask: + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + accept_nums = share_inputs["accept_num"][:real_bsz].reshape([-1]) + target_logits = target_logits[: accept_nums.sum()] + # Derive target probs from already-extracted target_logits; avoids a second kernel call. + target_probs = F.softmax(target_logits, axis=-1) + accept_top_p, accept_top_k, _ = build_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + share_inputs["seq_lens_this_time"], + share_inputs["cu_seqlens_q_output"], + token_num_output_cpu, + increment_value, + ) + + indices_window_gpu, mask_window_gpu, logz_per_batch, mask_bsz = _compute_sampling_mask( + target_probs, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, + ) + # Async D2H copy with event + indices_window_cpu = paddle.empty_like( + indices_window_gpu, dtype=indices_window_gpu.dtype, device="cpu" + ).pin_memory() + mask_window_cpu = paddle.empty_like( + mask_window_gpu, dtype=mask_window_gpu.dtype, device="cpu" + ).pin_memory() + indices_window_cpu.copy_(indices_window_gpu, False) + mask_window_cpu.copy_(mask_window_gpu, False) + sampler_output.sampling_mask = (indices_window_cpu, mask_window_cpu, mask_bsz) + sampler_output.logz_per_batch = logz_per_batch return sampler_output def forward_xpu( @@ -1269,7 +1651,18 @@ def gather_logprobs( indices = token_ids top_logprobs = token_logprobs - return LogprobsTensors(indices, top_logprobs, token_ranks) + if current_platform.is_cuda(): + indices_cpu = paddle.empty_like(indices, device="cpu").pin_memory() + top_logprobs_cpu = paddle.empty_like(top_logprobs, device="cpu").pin_memory() + token_ranks_cpu = paddle.empty_like(token_ranks, device="cpu").pin_memory() + indices_cpu.copy_(indices, False) + top_logprobs_cpu.copy_(top_logprobs, False) + token_ranks_cpu.copy_(token_ranks, False) + else: + indices_cpu = indices.cpu() + top_logprobs_cpu = top_logprobs.cpu() + token_ranks_cpu = token_ranks.cpu() + return LogprobsTensors(indices_cpu, top_logprobs_cpu, token_ranks_cpu) def forward_cuda( self, diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 4e75ba1d90b..32a37f21967 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm +from fastdeploy.model_executor.layers.quantization.fp8_utils import ( + per_token_group_quant_fp8, +) from fastdeploy.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, ) @@ -58,25 +61,14 @@ ) from fastdeploy.platforms import current_platform -if current_platform.is_cuda() or current_platform.is_maca(): - from fastdeploy.model_executor.ops.gpu import ( - get_position_ids_and_mask_encoder_batch, - ) - -from fastdeploy.model_executor.layers.quantization.fp8_utils import ( - per_token_group_quant_fp8, -) -from fastdeploy.platforms import current_platform - if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import ( cp_gather_indexer_k_quant_cache, indexer_k_quant_and_cache, + merge_prefill_decode_output, radix_topk_ragged_transform, ) - paddle.enable_compat(scope={"deep_gemm"}) - class DeepSeekV3MLP(nn.Layer): """ @@ -345,7 +337,6 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ @@ -400,7 +391,6 @@ def forward( fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim]) fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim] fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim]) - fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype) fmha_out = fmha_out_prefill if need_do_decode: # max_dec_len_this_time @@ -435,7 +425,17 @@ def forward( ) if need_do_prefill: - fmha_out += fmha_out_decode + merge_prefill_decode_output( + fmha_out, + fmha_out_decode, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_attention_heads_tp, + self.v_head_dim, + 1, + ) else: fmha_out = fmha_out_decode @@ -443,33 +443,6 @@ def forward( return output -def compute_slot_mapping( - block_tables: paddle.Tensor, # [num_reqs, max_blocks_per_req] - positions: paddle.Tensor, # [num_tokens] 每个token的位置 - batch_id_per_token: paddle.Tensor, # [num_tokens] 每个token属于哪个请求 - block_size: int, -) -> paddle.Tensor: - """ - 计算 slot_mapping - - 公式: slot = block_id * block_size + offset_in_block - """ - # 1. 计算每个 token 对应的 block 索引 - block_idx = positions // block_size # [num_tokens] - - # 2. 从 block_tables 中查表获取 block_id - # block_tables[batch_id_per_token, block_idx] - block_ids = block_tables[batch_id_per_token, block_idx] # [num_tokens] - - # 3. 计算在 block 内的偏移 - block_offset = positions % block_size # [num_tokens] - - # 4. 计算 slot_mapping - slot_mapping = block_ids * block_size + block_offset - - return slot_mapping.cast(paddle.int64) - - import triton import triton.language as tl @@ -653,19 +626,14 @@ def forward( weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.index_n_heads**-0.5 weights = weights.squeeze(-1) - slot_mapping = compute_slot_mapping( - forward_meta.block_tables, - forward_meta.position_ids, - forward_meta.batch_id_per_token, - 64, - ) - indexer_top_k = paddle.full([q_fp8.shape[0], self.index_topk], -1, dtype="int32") # indexer write_cache - indexer_k_quant_and_cache(k, self.indexer_cache, slot_mapping, self.quant_block_size, self.scale_fmt) + indexer_k_quant_and_cache( + k, self.indexer_cache, forward_meta.slot_mapping, self.quant_block_size, self.scale_fmt + ) - import deep_gemm + from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm if forward_meta.max_len_tensor_cpu[1]: @@ -927,7 +895,6 @@ def forward( forward_meta: ForwardMeta, hidden_states: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states) @@ -1045,7 +1012,6 @@ def forward( hidden_states: paddle.Tensor, residual: paddle.Tensor, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ if hidden_states.shape[0] > 0: @@ -1053,7 +1019,7 @@ def forward( hidden_states, residual_input=residual, forward_meta=forward_meta ) - hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch) + hidden_states = self.self_attn(forward_meta, hidden_states, position_ids) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) else: @@ -1109,7 +1075,6 @@ def forward( ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta, position_ids: paddle.Tensor, - mask_encoder_batch: paddle.Tensor, ): """ """ hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta) @@ -1120,8 +1085,7 @@ def forward( forward_meta, hidden_states, residual, - position_ids, - mask_encoder_batch, + position_ids.cast(paddle.int32), ) out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0] @@ -1156,12 +1120,6 @@ def __init__(self, fd_config: FDConfig): num_embeddings=fd_config.model_config.vocab_size, prefix="lm_head", ) - self.position_ids_buffer = paddle.empty( - [fd_config.scheduler_config.max_num_batched_tokens], dtype=paddle.int32 - ) - self.mask_encoder_batch_buffer = paddle.empty( - [fd_config.scheduler_config.max_num_batched_tokens, 1], dtype=paddle.int32 - ) @classmethod def name(cls): @@ -1258,25 +1216,6 @@ def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta logits[:, self.ori_vocab_size :] = -float("inf") return logits - def pre_process(self, forward_meta): - """ """ - seq_lens_encoder = forward_meta.seq_lens_encoder - seq_lens_decoder = forward_meta.seq_lens_decoder - seq_lens_this_time = forward_meta.seq_lens_this_time - - current_total_tokens = forward_meta.ids_remove_padding.shape[0] - position_ids = self.position_ids_buffer[:current_total_tokens] - mask_encoder_batch = self.mask_encoder_batch_buffer[:current_total_tokens] - - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, - seq_lens_decoder, - seq_lens_this_time, - position_ids, - mask_encoder_batch, - ) - return position_ids, mask_encoder_batch - def empty_input_forward(self, forward_meta): """ empty_input_forward @@ -1297,18 +1236,16 @@ def forward( forward_meta: ForwardMeta, ): ids_remove_padding = inputs["ids_remove_padding"] - forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta) hidden_states = self.model( ids_remove_padding=ids_remove_padding, forward_meta=forward_meta, position_ids=forward_meta.position_ids, - mask_encoder_batch=mask_encoder_batch, ) return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class DeepSeekV3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 4cc4306de5f..bf8b3d93481 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -701,9 +701,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) @ModelRegistry.register_model_class( diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index f4d70108e4b..b6fa97ab0ba 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -829,9 +829,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.ernie.clear_grpah_opt_backend(fd_config=self.fd_config) + self.ernie.clear_graph_opt_backend(fd_config=self.fd_config) class Ernie4_5_VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py index b32ebb2ced9..0df0f7103cb 100644 --- a/fastdeploy/model_executor/models/glm4_moe.py +++ b/fastdeploy/model_executor/models/glm4_moe.py @@ -25,6 +25,7 @@ from paddleformers.transformers import PretrainedModel from paddleformers.utils.log import logger +import fastdeploy from fastdeploy.config import FDConfig from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -63,6 +64,14 @@ def __init__( reduce_results: bool = True, ) -> None: super().__init__() + self.expert_parallel_size = fd_config.parallel_config.expert_parallel_size + self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size + self.use_tp = self.tensor_parallel_size > 1 + self.use_ep = self.expert_parallel_size > 1 + self.enable_all_reduce_fusion = fd_config.parallel_config.enable_flashinfer_allreduce_fusion and ( + self.use_tp and not self.use_ep + ) + # shared experts not split when use_sequence_parallel_moe in ep + tp if ( fd_config.parallel_config.use_sequence_parallel_moe @@ -100,6 +109,7 @@ def __init__( output_size=fd_config.model_config.hidden_size, with_bias=False, reduce_results=reduce_results, + enable_all_reduce_fusion=self.enable_all_reduce_fusion, ) self.act_fn = SiluAndMul( @@ -129,10 +139,12 @@ def __init__( self.tensor_parallel_size = fd_config.parallel_config.tensor_parallel_size self.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank self.tp_group = fd_config.parallel_config.tp_group - self.use_ep = self.expert_parallel_size > 1 self.use_tp = self.tensor_parallel_size > 1 - + self.last_layer_id = fd_config.model_config.num_hidden_layers - 1 + self.enable_all_reduce_fusion = ( + fd_config.parallel_config.enable_flashinfer_allreduce_fusion and layer_id != self.last_layer_id + ) self.n_routed_experts: int = fd_config.model_config.n_routed_experts self.n_shared_experts: int = fd_config.model_config.n_shared_experts @@ -182,6 +194,7 @@ def __init__( layer_idx=layer_id, gate_correction_bias=self.gate.e_score_correction_bias, weight_key_map=weight_key_map, + topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, ) if self.n_shared_experts > 0: @@ -199,8 +212,10 @@ def forward(self, x, forward_meta: ForwardMeta = None): if self.n_shared_experts > 0: out = out + self.shared_experts(x) if self.merge_ffn_tp: - # Both branches produced partial sums; combine first, then single all-reduce. - out = tensor_model_parallel_all_reduce(out, self.tp_group) + need_tp_all_reduce_fusion = self.enable_all_reduce_fusion and out.shape[0] <= 2048 + if not need_tp_all_reduce_fusion: + # Both branches produced partial sums; combine first, then single all-reduce. + out = tensor_model_parallel_all_reduce(out, self.tp_group) return out @@ -228,6 +243,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None input_size=fd_config.model_config.num_attention_heads * fd_config.model_config.head_dim, output_size=fd_config.model_config.hidden_size, layer_id=layer_id, + enable_all_reduce_fusion=fd_config.parallel_config.enable_flashinfer_allreduce_fusion, ) self.attn = Attention( @@ -265,6 +281,14 @@ def forward( return output +def rms_norm_func(x, weight, eps): + rms_norm_out = paddle.nn.functional.rms_norm(x, x.shape[-1:], weight, eps) + if isinstance(rms_norm_out, (tuple, list)): + return rms_norm_out[0].astype(weight.dtype) + else: + return rms_norm_out.astype(weight.dtype) + + class Glm4MoeDecoderLayer(nn.Layer): """ """ @@ -318,8 +342,11 @@ def forward( residual: paddle.Tensor = None, ): """ """ + + proxy_rmsnorm = rms_norm_func if fastdeploy.envs.FD_USE_PHI_RMSNORM else None + hidden_states, residual = self.input_layernorm( - hidden_states, residual_input=residual, forward_meta=forward_meta + hidden_states, residual_input=residual, forward_meta=forward_meta, proxy_rmsnorm=proxy_rmsnorm ) hidden_states = self.self_attn( @@ -328,7 +355,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual, proxy_rmsnorm=proxy_rmsnorm) hidden_states = self.mlp(hidden_states, forward_meta) @@ -550,9 +577,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Glm4MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/glm4_mtp.py b/fastdeploy/model_executor/models/glm4_mtp.py index c28023202d2..c700ea442c5 100644 --- a/fastdeploy/model_executor/models/glm4_mtp.py +++ b/fastdeploy/model_executor/models/glm4_mtp.py @@ -369,3 +369,7 @@ def forward( ) return hidden_states + + def clear_graph_opt_backend(self): + """Clear graph optimization backend, the captured cuda graph will be cleaned""" + self.model.clear_graph_opt_backend(fd_config=self.fd_config) diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 1bca09265ee..1d0ce349bf2 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -417,9 +417,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.qwen2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.qwen2.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen2PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index ebbf4f5aed0..b0bcf9d5883 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -341,9 +341,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3PretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py index a4d3f1579c3..3f2a6904248 100644 --- a/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py +++ b/fastdeploy/model_executor/models/qwen3_vl/qwen3_vl.py @@ -382,9 +382,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3VLPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index 74ca37ab695..95adc7ad0eb 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -453,9 +453,9 @@ def forward( return hidden_states - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """Clear graph optimization backend, the captured cuda graph will be cleaned""" - self.model.clear_grpah_opt_backend(fd_config=self.fd_config) + self.model.clear_graph_opt_backend(fd_config=self.fd_config) class Qwen3MoePretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 0fc6bfde5d0..a9309f5a3af 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -22,9 +22,14 @@ from fastdeploy import envs from fastdeploy.config import SpeculativeConfig +from fastdeploy.model_executor.ops.gpu import ( + mtp_save_first_token, + mtp_save_first_token_with_topk, +) from fastdeploy.platforms import current_platform from fastdeploy.worker.input_batch import ( InputBatch, + ProposerInputBatch, recover_batch_index_for_output, recover_batch_index_for_sampler_output, ) @@ -109,43 +114,44 @@ from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) +from fastdeploy.model_executor.layers.sample.logprobs import ( + logprobs_renormalize_with_logz, +) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata +from fastdeploy.model_executor.layers.sample.sampler import _extract_sparse_indices from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, SamplerOutput DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" -if current_platform.is_cuda(): - def async_set_value(tgt, src): - if isinstance(src, (int, float, bool)): - src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) - elif isinstance(src, (list, np.array)): - dtype_str = str(tgt.dtype).split(".")[1] - if isinstance(src, list): - src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32") +def async_set_value(tgt, src): + if isinstance(src, (int, float, bool)): + src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) + elif isinstance(src, (list, np.ndarray)): + dtype_str = str(tgt.dtype).split(".")[1] + if isinstance(src, list): + src = np.array(src, dtype=dtype_str if dtype_str != "bfloat16" else "float32") + if current_platform.is_cuda(): if str(src.dtype) != dtype_str: srt_tensor = paddle.empty(tgt.shape, dtype=str(src.dtype)) src = custom_numpy_to_tensor(src, srt_tensor) else: return custom_numpy_to_tensor(src, tgt) - elif isinstance(src, paddle.Tensor): - pass else: - raise ValueError("async_set_value unsupported src type: {}".format(type(src))) - if src.shape != tgt.shape: - src = src.reshape(tgt.shape) - if src.dtype != tgt.dtype: - src = src.cast(tgt.dtype) - if src.place != tgt.place: - src = src.to(tgt.place) - tgt.copy_(src, blocking=False) - -else: - - def async_set_value(*args, **kwargs): - raise RuntimeError("async_set_value is only available on CUDA") + src = paddle.to_tensor(src, dtype=tgt.dtype) + elif isinstance(src, paddle.Tensor): + pass + else: + raise ValueError("async_set_value unsupported src type: {}".format(type(src))) + if src.shape != tgt.shape: + src = src.reshape(tgt.shape) + if src.dtype != tgt.dtype: + src = src.cast(tgt.dtype) + if src.place != tgt.place: + src = src.to(tgt.place) + tgt.copy_(src, blocking=False) def pre_process( @@ -216,6 +222,7 @@ def _build_stream_transfer_data( pooler_outputs: List[PoolingSequenceGroupOutput] = None, logprobs: Optional[LogprobsTensors] = None, prompt_logprobs_list: Optional[LogprobsTensors] = None, + sampling_mask: Optional[List[np.ndarray]] = None, ): """Split output_tokens and output""" @@ -225,6 +232,8 @@ def _build_stream_transfer_data( output_tokens = output_tokens.numpy().reshape([-1]) output_tokens_lists = np.split(output_tokens, output_tokens.shape[0]) + sampling_mask_list = sampling_mask + for bid, output_token_per_sample in enumerate(output_tokens_lists): stream_transfer_data = StreamTransferData( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid @@ -233,6 +242,8 @@ def _build_stream_transfer_data( stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1) if prompt_logprobs_list: stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] + if sampling_mask_list is not None: + stream_transfer_data.sampling_mask = sampling_mask_list[bid] stream_transfer_datas.append(stream_transfer_data) elif pooler_outputs is not None: for bid, pooler_output in enumerate(pooler_outputs): @@ -325,18 +336,11 @@ def post_process_normal( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( - positions=routing_replay_manager.pending_update_positions - ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - finished_batch_ids = paddle.flatten(paddle.isin(sampler_output.sampled_token_ids, model_output.eos_token_id)) - context_lens = model_output.seq_lens_decoder + model_output.seq_lens_encoder - routing_replay_manager.put_finished_batch(finished_batch_ids=finished_batch_ids, seq_lens_decoder=context_lens) + slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + position_ids_gpu = share_inputs.get("position_ids_buffer") + num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu, position_ids_gpu) # 2. Update the input buffer of the model with paddle.framework._no_check_dy2st_diff(): @@ -368,6 +372,9 @@ def post_process_normal( model_output.is_block_step, ) + # logprobs renormalization with logz is deferred to save_output, + # so that async D2H of logz_per_batch has more time to complete. + def save_output_normal( model_output: ModelOutputData, @@ -375,7 +382,23 @@ def save_output_normal( share_inputs: Dict[str, paddle.Tensor], async_output_queue: queue.Queue = None, save_each_rank: bool = False, + sampling_mask_async_queue: Optional[queue.Queue] = None, ): + # Extract sparse indices from pinned CPU buffers + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) + + # Renormalize logprobs with logz (deferred from post_process for better overlap). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + # Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -393,6 +416,7 @@ def save_output_normal( recover_share_inputs_map["sampled_token_ids"], logprobs=sampler_output.logprobs_tensors, prompt_logprobs_list=model_output.prompt_logprobs_list, + sampling_mask=sampler_output.sampling_mask, ) async_output_queue.put(output) else: @@ -429,10 +453,17 @@ def save_output_normal( recover_share_inputs_map["last_preempted_idx"], model_output.mp_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled (async via background thread). + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask already resolved at function entry. + assert ( + sampling_mask_async_queue is not None + ), "sampling_mask_async_queue must not be None when sampling_mask is enabled" + sampling_mask_async_queue.put((sampler_output.sampling_mask, None)) share_inputs["last_preempted_idx"][:] = 0 -def post_process_specualate( +def post_process_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, @@ -475,27 +506,11 @@ def post_process_specualate( # Routing replay if routing_replay_manager is not None: - # Update host cache - slot_mapping = routing_replay_manager.compute_slot_mapping( - positions=routing_replay_manager.pending_update_positions - ) - routing_replay_manager.update_host_cache( - positions=routing_replay_manager.pending_update_positions, slot_mapping=slot_mapping - ) - - # Put routing of finished requests to store - last_accept_token = paddle.full_like(model_output.accept_tokens, -1) - col_indices = paddle.arange(model_output.accept_tokens.shape[1], dtype=model_output.accept_num.dtype) - mask = col_indices < paddle.unsqueeze(model_output.accept_num, 1) - last_accept_token[mask] = model_output.accept_tokens[mask] - eos_tokens_flat = model_output.eos_token_id.flatten() - isin_mask = paddle.isin(last_accept_token, eos_tokens_flat) - finished_batch_ids = isin_mask.any(axis=-1) - context_lens = model_output.seq_lens_encoder + model_output.seq_lens_decoder - routing_replay_manager.put_finished_batch( - finished_batch_ids=finished_batch_ids, - seq_lens_decoder=context_lens, - ) + slot_mapping_gpu = share_inputs["slot_mapping_buffer"] + position_ids_gpu = share_inputs.get("position_ids_buffer") + num_tokens = int(share_inputs["ids_remove_padding"].shape[0]) + if routing_replay_manager.tp_rank == 0: + routing_replay_manager.prepare_pending_save(num_tokens, slot_mapping_gpu, position_ids_gpu) # Unified state update: merges speculate_update + speculate_set_value_by_flags_and_idx # into a single kernel launch. Handles EOS detection, max_dec_len truncation, step_idx @@ -520,15 +535,109 @@ def post_process_specualate( model_output.max_dec_len, # max_dec_len ) + # logprobs renormalization with logz is deferred to save_output, + # so that async D2H of logz_per_batch has more time to complete. + -def save_output_specualate( +def save_output_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: InputBatch, + local_rank: int, + tensor_parallel_rank: int, save_each_rank: bool = False, - skip_save_output: bool = False, + sampling_mask_async_queue: Optional[queue.Queue] = None, + is_mtp_prefill: bool = False, + proposer_share_inputs: Optional[ProposerInputBatch] = None, ): - if not skip_save_output: + # Resolve deferred async D2H: sync event once at the top so all paths below + # can safely read sampling_mask and logz_per_batch. + mask_bsz = None + if sampler_output.sampling_mask is not None: + indices_window_cpu, mask_window_cpu, mask_bsz = sampler_output.sampling_mask + sampler_output.sampling_mask = _extract_sparse_indices( + indices_window_cpu.numpy(), mask_window_cpu.numpy(), mask_bsz + ) + + # Renormalize logprobs with logz (deferred from post_process for better overlap). + if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None: + assert mask_bsz is not None + sampler_output.logprobs_tensors = LogprobsTensors( + logprob_token_ids=sampler_output.logprobs_tensors.logprob_token_ids[:mask_bsz], + logprobs=sampler_output.logprobs_tensors.logprobs[:mask_bsz], + selected_token_ranks=sampler_output.logprobs_tensors.selected_token_ranks[:mask_bsz], + ) + sampler_output.logprobs_tensors = logprobs_renormalize_with_logz( + sampler_output.logprobs_tensors.logprobs, + sampler_output.logz_per_batch, + sampler_output.logprobs_tensors, + ) + + if is_mtp_prefill: + assert proposer_share_inputs is not None + if tensor_parallel_rank == 0: + skip_chunk_prefill = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + if sampler_output.logprobs_tensors is None: + recover_proposer_share_inputs_map = recover_batch_index_for_output( + proposer_share_inputs, + proposer_share_inputs.index_to_batch_id, + proposer_share_inputs.enable_pd_reorder, + [ + "base_model_draft_tokens", + "seq_lens_decoder", + "prompt_lens", + "step_idx", + ], + ) + mtp_save_first_token( + recover_proposer_share_inputs_map["base_model_draft_tokens"], + proposer_share_inputs["not_need_stop"], + recover_proposer_share_inputs_map["seq_lens_decoder"], + recover_proposer_share_inputs_map["prompt_lens"], + recover_proposer_share_inputs_map["step_idx"], + local_rank, + save_each_rank, + skip_chunk_prefill, + ) + else: + recover_share_inputs_map = recover_batch_index_for_output( + share_inputs, + model_output.index_to_batch_id, + model_output.enable_pd_reorder, + [ + "sampled_token_ids", + "accept_tokens_cpu", + "accept_num_cpu", + "seq_lens_decoder_cpu", + "prompt_lens_cpu", + "last_preempted_idx", + ], + ) + recover_batch_index_for_sampler_output( + sampler_output, model_output.index_to_batch_id, model_output.enable_pd_reorder + ) + recover_proposer_share_inputs_map = recover_batch_index_for_output( + proposer_share_inputs, + proposer_share_inputs.index_to_batch_id, + proposer_share_inputs.enable_pd_reorder, + ["base_model_draft_tokens"], + ) + mtp_save_first_token_with_topk( + recover_proposer_share_inputs_map["base_model_draft_tokens"], + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + recover_share_inputs_map["accept_num_cpu"], + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + recover_share_inputs_map["seq_lens_decoder_cpu"], + recover_share_inputs_map["prompt_lens_cpu"], + recover_share_inputs_map["last_preempted_idx"], + 3, # mtype + model_output.mp_rank, + save_each_rank, + ) + else: if sampler_output.logprobs_tensors is None: recover_share_inputs = recover_batch_index_for_output( share_inputs, @@ -585,6 +694,16 @@ def save_output_specualate( model_output.mp_rank, save_each_rank, ) + # Send sampling_mask via ZMQ side-channel when enabled (async via background thread). + if sampler_output.sampling_mask is not None and model_output.mp_rank == 0: + # sampling_mask already resolved at function entry. + # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req). + real_bsz = model_output.accept_num.shape[0] + accept_nums = model_output.accept_num[:real_bsz].flatten().tolist() + assert ( + sampling_mask_async_queue is not None + ), "sampling_mask_async_queue must not be None when sampling_mask is enabled" + sampling_mask_async_queue.put((sampler_output.sampling_mask, accept_nums)) share_inputs["last_preempted_idx"][:] = 0 @@ -618,7 +737,7 @@ def post_process( ) else: if speculative_decoding: - post_process_specualate( + post_process_speculate( sampler_or_pooler_output, model_output, share_inputs, diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index e63603047be..952b82b2a60 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -14,6 +14,8 @@ # limitations under the License. """ +import importlib +import importlib.util import os import re from collections.abc import Mapping @@ -129,6 +131,35 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1): return weight_or_paramter +def _is_gdr_checkpoint_transfer_dynamic_load_config(fd_config: FDConfig) -> bool: + load_config = fd_config.load_config + if not load_config.dynamic_load_weight: + return False + return envs.FD_USE_GDR_CHECKPOINT_TRANSFER + + +def _copy_gdr_checkpoint_transfer_transposed_weight_attrs(src, dst): + attr_names = ( + "weight_loader", + "output_dim", + "weight_need_transpose", + "is_distributed", + "split_axis", + "tp_row_bias", + ) + for name in attr_names: + if hasattr(src, name): + setattr(dst, name, getattr(src, name)) + if hasattr(src, "output_dim") and src.output_dim is not None: + dst.output_dim = not src.output_dim + dst.weight_need_transpose = not getattr(src, "weight_need_transpose", False) + if hasattr(src, "split_axis"): + if len(src.shape) == 2 and src.split_axis in (0, 1): + dst.split_axis = 1 - src.split_axis + elif len(src.shape) == 3 and src.split_axis in (1, 2): + dst.split_axis = 3 - src.split_axis + + def process_weight_transpose(layer, weight_name): weight = getattr(layer, weight_name) if len(weight.shape) == 2: @@ -141,6 +172,8 @@ def process_weight_transpose(layer, weight_name): default_initializer=paddle.nn.initializer.Constant(0), is_bias=False, ) + if _is_gdr_checkpoint_transfer_dynamic_load_config(layer.fd_config): + _copy_gdr_checkpoint_transfer_transposed_weight_attrs(weight, weight_tmp) if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False): free_tensor(weight) setattr(layer, weight_name, weight_tmp) @@ -346,6 +379,8 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None): f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" ) loaded_weight = get_tensor(loaded_weight) + if not param._is_initialized(): + param.initialize() param.copy_(loaded_weight, False) return fn @@ -553,6 +588,34 @@ def fn(loaded_weight_name, is_moe): return fn +def try_import(modules, name=None, fail_msg=None): + """ + try_import + """ + if not isinstance(modules, (list, tuple)): + modules = [modules] + + for m in modules: + assert isinstance(m, str), m + try: + m = importlib.import_module(m) + except ImportError: + m = None + + if m is not None: + if name is None: + return m + elif hasattr(m, name): + return getattr(m, name) + + if fail_msg is not None: + logger.warning(fail_msg) + + +def has_flashinfer(): + return importlib.util.find_spec("flashinfer") is not None + + @cache def get_sm_version(): if paddle.cuda.is_available(): diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 9e32ea34876..9232989dd48 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -55,6 +55,29 @@ DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" +def async_set_value(tgt, src): + if isinstance(src, (int, float, bool)): + src = paddle.full(tgt.shape, fill_value=src, dtype=tgt.dtype) + elif isinstance(src, (list, np.ndarray)): + dtype_str = str(tgt.dtype).split(".")[1] + np_dtype = dtype_str if dtype_str != "bfloat16" else "float32" + if isinstance(src, list): + src = np.array(src, dtype=np_dtype) + # TODO: support async_numpy_to_tensor + src = paddle.to_tensor(src, dtype=tgt.dtype) + elif isinstance(src, paddle.Tensor): + pass + else: + raise ValueError("async_set_value unsupported src type: {}".format(type(src))) + if src.shape != tgt.shape: + src = src.reshape(tgt.shape) + if src.dtype != tgt.dtype: + src = src.cast(tgt.dtype) + if src.place != tgt.place: + src = src.to(tgt.place) + tgt.copy_(src, blocking=False) + + def _build_stream_transfer_data( output_tokens: paddle.Tensor, pooler_outputs: List = None, @@ -381,7 +404,7 @@ def xpu_post_process_normal( share_inputs["preempted_idx"][:] = 0 -def xpu_post_process_specualate( +def xpu_post_process_speculate( sampler_output: SamplerOutput, model_output: ModelOutputData, share_inputs: Dict[str, paddle.Tensor], diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index b32e01c954f..dce21bb5963 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -46,3 +46,7 @@ class StreamTransferData: accept_num: Optional[np.array] = None # [num_reqs, hidden_size] pooler_output: Optional[np.array] = None + # 1-D int32 numpy array of vocab indices retained by top_p/top_k for + # this request. Sparse format: only retained positions, not a dense + # vocab-sized bool mask. + sampling_mask: Optional[np.array] = None diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 1ab0b48f350..9e37cc52f11 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -37,6 +37,7 @@ Request, RequestMetrics, RequestOutput, + RequestStatus, SpeculateMetrics, ) from fastdeploy.inter_communicator import ZmqIpcServer @@ -68,6 +69,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.cached_generated_tokens = cached_generated_tokens self.resource_manager = None self.scheduler_metrics_logger = None + self._benchmark_logger = None self.engine_worker_queue = engine_worker_queue self.tokens_counter = Counter() self.split_connector = split_connector @@ -83,6 +85,14 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.speculative_decoding = self.cfg.speculative_config.method is not None self.use_logprobs = self.cfg.model_config.enable_logprob + self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False) + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask: + rank_id = self.cfg.parallel_config.local_data_parallel_id + port = self.cfg.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_server = ZmqIpcServer( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL + ) + llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}") self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob if self.speculative_decoding: @@ -132,6 +142,65 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.health_lock = threading.Lock() self.engine_output_token_hang = False + # Routing replay: attach to SharedMemory routing_host_buffer (lazy init after profiling) + self.routing_host_view = None + self._routing_host_view_init_attempted = False + self.routing_cache_manager = None # Set by Engine after profiling for local/rdma store dispatch + + def _init_routing_host_view(self): + """Attach to SharedMemory routing_host_buffer created by Engine. Called lazily.""" + self._routing_host_view_init_attempted = True + if not self.cfg.routing_replay_config.enable_routing_replay: + return + try: + from fastdeploy.cache_manager.routing_cache_manager import ( + RoutingHostBufferView, + ) + + rrc = self.cfg.routing_replay_config + cache_config = self.cfg.cache_config + + dp_suffix = str(self.cfg.parallel_config.local_engine_worker_queue_port) + shm_name = f"routing_host_buffer.{dp_suffix}" + num_gpu_blocks = cache_config.total_block_num + max_num_kv_tokens = num_gpu_blocks * cache_config.block_size + shape = (max_num_kv_tokens, rrc.num_moe_layers, rrc.moe_top_k) + + self.routing_host_view = RoutingHostBufferView(shape=shape, dtype=rrc.routing_dtype, shm_name=shm_name) + self._routing_block_size = cache_config.block_size + llm_logger.info(f"[R3] TokenProcessor attached to RoutingHostBuffer: {shm_name}") + except FileNotFoundError: + llm_logger.warning("[R3] RoutingHostBuffer SharedMemory not found, routing gather disabled.") + except Exception as e: + llm_logger.warning(f"[R3] Failed to attach to RoutingHostBuffer: {e}") + + def _gather_routing_for_finished_request(self, task, seq_len: int): + """ + Gather complete routing data for a finished request from routing_host_buffer. + + Args: + task: Request task with block_tables + seq_len: Total sequence length + + Returns: + numpy array [seq_len, num_moe_layers, top_k] or None + """ + if self.routing_host_view is None and not self._routing_host_view_init_attempted: + self._init_routing_host_view() + if self.routing_host_view is None: + return None + + import math + + block_size = self._routing_block_size + block_ids = task.block_tables[: math.ceil(seq_len / block_size)] + positions = np.arange(seq_len) + block_indices = positions // block_size + offsets = positions % block_size + slot_mapping = np.array(block_ids)[block_indices] * block_size + offsets + + return self.routing_host_view.gather(slot_mapping) + def healthy(self): """ whether token processor is healthy @@ -168,6 +237,9 @@ def set_resource_manager(self, resource_manager): def set_scheduler_metrics_logger(self, scheduler_metrics_logger): self.scheduler_metrics_logger = scheduler_metrics_logger + def set_benchmark_logger(self, benchmark_logger): + self._benchmark_logger = benchmark_logger + def _is_decode_stage(self, task): if task is None: return False @@ -264,8 +336,8 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status() - if not is_prefill: - self._record_completion_metrics(task, current_time) + self._record_completion_metrics(task, current_time) + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, batch_id, task, result, is_prefill) break return result @@ -291,6 +363,7 @@ def _process_batch_output_use_zmq(self, receive_datas): ): llm_logger.info(f"start to recycle abort request_id {task_id}") self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set and token_ids[-1] == PREEMPTED_TOKEN_ID @@ -329,6 +402,7 @@ def _process_batch_output_use_zmq(self, receive_datas): prompt_token_ids=task.prompt_token_ids, outputs=PoolingOutput(data=pooler_output), ) + self._finalize_routing(task_id, task, result, False) self._recycle_resources(task_id, i, task, result, False) batch_result.append(result) else: @@ -357,6 +431,8 @@ def _process_batch_output_use_zmq(self, receive_datas): result.prompt_logprobs = stream_data.prompt_logprobs except Exception as e: llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") + if getattr(stream_data, "sampling_mask", None) is not None: + result.outputs.sampling_mask = stream_data.sampling_mask.tolist() if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -513,42 +589,128 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3): except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") + def _finalize_routing(self, task_id, task, result, is_prefill=False): + """ + Gather routing data before blocks are freed. + Must be called before _recycle_resources so that block_tables are still valid. + + - PD P node (is_prefill=True): gather prefill-only routing, attach to result for sending to D. + - Non-PD / D node (result.finished): gather full routing (prompt + output), + either attach to result ("response" mode) or dispatch to store ("local"/"rdma" mode). + """ + if not self.cfg.routing_replay_config.enable_routing_replay: + return + if result is None: + return + + try: + if is_prefill: + if result.error_code == 200: + seq_len = task.prompt_token_ids_len + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif result.finished: + store_type = self.cfg.routing_replay_config.routing_store_type + seq_len = ( + task.prompt_token_ids_len + len(task.output_token_ids) + if hasattr(task, "output_token_ids") + else task.prompt_token_ids_len + ) + if task.output_token_ids[-1] in task.eos_token_ids: + seq_len = seq_len - 1 # Ignore eos token + if store_type == "response": + routing_data = self._gather_routing_for_finished_request(task, seq_len) + if routing_data is not None: + result.routing_data = routing_data + elif self.routing_cache_manager is not None: + self.routing_cache_manager.on_request_finished( + request_id=task_id, + block_table=task.block_tables, + seq_len=seq_len, + ) + except Exception as e: + llm_logger.warning(f"[R3] Failed to finalize routing for {task_id}: {e}") + def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False): """ recycle resources """ if is_prefill: - start_time = time.time() - result.metrics.wait_for_sending_cache_time = time.time() - while True: - finished_task_ids = self.engine_worker_queue.get_finished_req() - if len(finished_task_ids) > 0: - for finished_task_id in finished_task_ids: - llm_logger.info(f"finished_task_id: {finished_task_id}") - self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] - if task_id in self.prefill_result_status: - if self.prefill_result_status[task_id] != "finished": - result.error_code = 400 - result.error_message = f"{task_id} failed to {self.prefill_result_status[task_id]}" - llm_logger.info( - f"wait for sending cache, request_id: {task_id}, cost seconds: {time.time()-start_time:.5f}" + if envs.FD_PD_TRANSFER_VIA_STORAGE: + # Storage pool mode: bypass CacheMessager entirely. + # At this point, all transformer layers are complete and KV cache is in GPU memory. + # Directly write cache to storage and send first token to D. + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + if result.error_code == 200: + write_cache_start_time = time.time() + llm_logger.info(f"[PD Storage] P writing cache to storage (direct), request_id: {task_id}") + write_success = self.resource_manager.cache_manager.write_all_cache_to_storage( + task, include_output=False ) - result.metrics.send_request_output_to_decode_time = time.time() - self.split_connector.send_first_token(task.disaggregate_info, [result]) - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - self.resource_manager.finish_requests_async(task_id) + if not write_success: + result.error_code = 501 + result.error_msg = f"P instance failed to write cache to storage for request {task_id}" + llm_logger.error(f"[PD Storage] {result.error_msg}") else: - self.resource_manager.stop_flags[index] = True - self.resource_manager.tasks_list[index] = None - self.resource_manager._recycle_block_tables(task) - if task_id in self.resource_manager.req_dict: - del self.resource_manager.req_dict[task_id] - break + llm_logger.info( + f"[PD Storage] P finished writing cache to storage (direct), " + f"request_id: {task_id}, cost: {time.time()-write_cache_start_time:.5f}s" + ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) else: - # TODO: Refine checking sending cache and do not keep waiting - if time.time() - start_time > 30: - llm_logger.warning(f"wait for sending cache, {task_id}") - time.sleep(0.002) + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + else: + # RDMA/IPC mode: poll CacheMessager for transfer completion + start_time = time.time() + result.metrics.wait_for_sending_cache_time = time.time() + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_START, task_id, getattr(task, "user", "")) + + while True: + finished_task_ids = self.engine_worker_queue.get_finished_req() + if len(finished_task_ids) > 0: + for finished_task_id in finished_task_ids: + llm_logger.info(f"finished_task_id: {finished_task_id}") + self.prefill_result_status[finished_task_id[0]] = finished_task_id[1] + if task_id in self.prefill_result_status: + if self.prefill_result_status[task_id] != "finished": + result.error_code = 501 + result.error_msg = ( + f"PD Error: prefill failed to send cache to decode, " + f"{task_id}, {self.prefill_result_status[task_id]}" + ) + self.prefill_result_status.pop(task_id) + llm_logger.info( + f"wait for sending cache, request_id: {task_id}, " + f"cost seconds: {time.time()-start_time:.5f}" + ) + trace_print(LoggingEventName.CHECK_CACHE_TRANSFER_END, task_id, getattr(task, "user", "")) + + result.metrics.send_request_output_to_decode_time = time.time() + self.split_connector.send_first_token(task.disaggregate_info, [result]) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + self.resource_manager.finish_requests_async(task_id) + else: + self.resource_manager.stop_flags[index] = True + self.resource_manager.tasks_list[index] = None + self.resource_manager._recycle_block_tables(task) + if task_id in self.resource_manager.req_dict: + del self.resource_manager.req_dict[task_id] + break + else: + # TODO: Refine checking sending cache and do not keep waiting + if time.time() - start_time > 30: + llm_logger.warning(f"wait for sending cache, {task_id}") + time.sleep(0.005) else: if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.finish_requests_async(task_id) @@ -672,12 +834,15 @@ def _process_batch_draft_tokens(self, mtype, batch, accept_num, tokens, scores, metrics=None, ) - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + tokens_i = tokens[i].tolist() + scores_i = scores[i].tolist() + ranks_i = ranks[i].tolist() + token_ids = [row[0] for row in tokens_i[: accept_num[i]]] for batch_token_index in range(len(token_ids)): - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_i[batch_token_index][0] + topk_token_ids = tokens_i[batch_token_index] + topk_logprobs = scores_i[batch_token_index] + sampled_rank = ranks_i[batch_token_index] if result.outputs.draft_top_logprobs is None: result.outputs.draft_top_logprobs = LogprobsLists( @@ -704,16 +869,19 @@ def _process_batch_output(self): mtype = 3 if self.cfg.speculative_config.method: if self.use_logprobs: - mtype = int(self.output_tokens[1, 0].item()) + # meta[1] packs message_flag (low 8 bits) and actual_topk (high 24 bits). + packed_meta1 = int(self.output_tokens[1, 0].item()) + mtype = packed_meta1 & 0xFF + actual_topk = packed_meta1 >> 8 batch = self.output_tokens[2, 0] accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( [batch, MAX_DRAFT_TOKENS, K + 1] - ) + )[:, :, :actual_topk] scores = ( self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] .numpy() - .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + .reshape([batch, MAX_DRAFT_TOKENS, K + 1])[:, :, :actual_topk] ) ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) @@ -722,22 +890,47 @@ def _process_batch_output(self): batch_result = self._process_batch_draft_tokens(mtype, batch, accept_num, tokens, scores, ranks) self.postprocess(batch_result, mtype) return + # Pre-convert full arrays to Python lists once for MTP target token path. + tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1] accept_num = tokens[2 : batch + 2] elif self.use_logprobs: - batch = self.output_tokens[1, 0] - tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)] - scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)] + # mtext[1] packs bsz (low 16 bits) and actual_topk (high 16 bits). + # actual_topk = max_num_logprobs written by save_output_topk, which + # equals the actual number of logprob columns in this step's message + # (top_logprobs+1 across the batch). Using actual_topk as stride + # avoids processing the K+1=21 fixed-size slots when fewer are needed. + packed = int(self.output_tokens[1, 0]) + batch = packed & 0xFFFF + actual_topk = (packed >> 16) & 0xFFFF + tokens = tokens[2 : batch * actual_topk + 2].reshape([batch, actual_topk]) + scores = self.output_scores[: batch * actual_topk].numpy().reshape([batch, actual_topk]) ranks = self.output_ranks[:batch].numpy() + # Pre-convert the full [batch, actual_topk] arrays to Python lists once, + # avoiding per-row .tolist() calls inside the loop below. + tokens_lists = tokens.tolist() + scores_lists = scores.tolist() + ranks_list = ranks.tolist() else: batch = self.output_tokens[1, 0] tokens = tokens[2 : batch + 2] + # Receive sampling constraints per request from ZMQ side-channel (if enabled). + # The worker sends a dict {batch_id: sparse_vocab_indices} each step, + # where the value is a list[int] or list[list[int]] of allowed token ids + sampling_masks_per_request = {} + if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"): + _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True) + if mask_data is not None and isinstance(mask_data, dict): + sampling_masks_per_request = mask_data + batch_result = list() # reschedule for i in range(batch): - if self.resource_manager.stop_flags[i]: + if self.resource_manager.stop_flags[i] or self.resource_manager.tasks_list[i] is None: continue recovery_stop = False @@ -759,6 +952,7 @@ def _process_batch_output(self): if envs.ENABLE_V1_KVCACHE_SCHEDULER: if task_id in self.resource_manager.to_be_aborted_req_id_set: self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) if task_id in self.resource_manager.to_be_rescheduled_request_id_set: self.resource_manager.reschedule_preempt_task(task_id) continue @@ -768,7 +962,7 @@ def _process_batch_output(self): llm_logger.info(f"recovery stop signal found at task {task_id}") token_ids = [RECOVERY_STOP_SIGNAL] elif self.use_logprobs: - token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + token_ids = [row[0] for row in tokens_lists[i][: accept_num[i]]] else: token_ids = tokens[ 2 @@ -793,6 +987,7 @@ def _process_batch_output(self): and token_id == PREEMPTED_TOKEN_ID ): self.resource_manager.recycle_abort_task(task_id) + self._put_abort_results(task) llm_logger.info(f"sync abortion for request_id {task_id} done.") if ( task_id in self.resource_manager.to_be_rescheduled_request_id_set @@ -820,6 +1015,8 @@ def _process_batch_output(self): continue self.total_step += 1 + if task.status == RequestStatus.RUNNING_PREFILL: + task.status = RequestStatus.RUNNING_DECODE current_time = time.time() trace_carrier = None if self.tokens_counter[task_id] == 0: @@ -868,6 +1065,9 @@ def _process_batch_output(self): result.num_input_image_tokens = task.multimodal_inputs.get("num_input_image_tokens", 0) result.num_input_video_tokens = task.multimodal_inputs.get("num_input_video_tokens", 0) + if self.use_sampling_mask and i in sampling_masks_per_request: + result.outputs.sampling_mask = sampling_masks_per_request[i] + if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) @@ -881,15 +1081,16 @@ def _process_batch_output(self): task.output_token_ids.append(token_id) if self.use_logprobs: if self.cfg.speculative_config.method: - result.outputs.logprob = float(scores[i, batch_token_index, 0]) - topk_token_ids = tokens[i, batch_token_index, :].tolist() - topk_logprobs = scores[i, batch_token_index, :].tolist() - sampled_rank = ranks[i, batch_token_index].item() + result.outputs.logprob = scores_lists[i][batch_token_index][0] + topk_token_ids = tokens_lists[i][batch_token_index] + topk_logprobs = scores_lists[i][batch_token_index] + sampled_rank = ranks_list[i][batch_token_index] else: - result.outputs.logprob = float(scores[i, 0]) - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() + # Use pre-converted lists (batch .tolist() done before the loop). + result.outputs.logprob = scores_lists[i][0] + topk_token_ids = tokens_lists[i] + topk_logprobs = scores_lists[i] + sampled_rank = ranks_list[i] if result.outputs.top_logprobs is None: result.outputs.top_logprobs = LogprobsLists( @@ -944,8 +1145,7 @@ def _process_batch_output(self): llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status(result) - if not is_prefill: - self._record_completion_metrics(task, current_time) + self._record_completion_metrics(task, current_time) llm_logger.info(f"task {task_id} received eos token. Recycling.") if ( envs.ENABLE_V1_KVCACHE_SCHEDULER @@ -955,6 +1155,7 @@ def _process_batch_output(self): self.resource_manager.cache_output_tokens( task ) # when enable prefix caching, cache kv cache for output tokens + self._finalize_routing(task_id, task, result, is_prefill) self._recycle_resources(task_id, i, task, result, is_prefill) llm_logger.info(f"eos token {task_id} Recycle end.") break @@ -971,6 +1172,10 @@ def _record_metrics(self, task, current_time, token_ids): if hasattr(task, "last_token_time") and task.last_token_time is not None: token_gen_time = current_time - task.last_token_time main_process_metrics.time_per_output_token.observe(token_gen_time) + if self._benchmark_logger: + if not hasattr(task, "_itl_samples"): + task._itl_samples = [] + task._itl_samples.append(token_gen_time) task.last_token_time = current_time # Record generation metrics @@ -987,17 +1192,44 @@ def _record_first_token_metrics(self, task, current_time): def _record_completion_metrics(self, task, current_time): """Record metrics when request completes""" + role = self.cfg.scheduler_config.splitwise_role metrics = task.metrics - if metrics.engine_recv_first_token_time: - decode_time = current_time - metrics.engine_recv_first_token_time - main_process_metrics.request_decode_time.observe(decode_time) - trace_print(LoggingEventName.INFERENCE_END, task.request_id, getattr(task, "user", "")) + + if role in ("mixed", "decode"): + if metrics.engine_recv_first_token_time: + decode_time = current_time - metrics.engine_recv_first_token_time + main_process_metrics.request_decode_time.observe(decode_time) + trace_print(LoggingEventName.INFERENCE_END, task.request_id, getattr(task, "user", "")) + + if role == "prefill": + trace_print(LoggingEventName.PREFILL_INFERENCE_END, task.request_id, getattr(task, "user", "")) + elif role == "decode": + trace_print(LoggingEventName.DECODE_INFERENCE_END, task.request_id, getattr(task, "user", "")) + trace_print(LoggingEventName.POSTPROCESSING_START, task.request_id, getattr(task, "user", "")) - main_process_metrics.num_requests_running.dec(1) main_process_metrics.request_success_total.inc() main_process_metrics.request_inference_time.observe(current_time - metrics.inference_start_time) main_process_metrics.request_generation_tokens.observe(self.tokens_counter[task.request_id]) + if self._benchmark_logger: + from fastdeploy.metrics.benchmark_metrics_logger import ( + CompletedRequestRecord, + ) + + record = CompletedRequestRecord( + request_id=task.request_id, + completion_time=current_time, + arrival_time=metrics.arrival_time or 0.0, + inference_start_time=metrics.inference_start_time or 0.0, + first_token_time=metrics.engine_recv_first_token_time or 0.0, + last_token_time=metrics.engine_recv_latest_token_time or current_time, + input_len=getattr(task, "prompt_token_ids_len", 0) or 0, + output_len=self.tokens_counter[task.request_id], + num_cached_tokens=getattr(task, "num_cached_tokens", 0) or 0, + itl_samples=getattr(task, "_itl_samples", []), + ) + self._benchmark_logger.on_request_completed(record) + def _record_speculative_decoding_metrics(self, accept_num): """Record metrics of speculative decoding""" if not hasattr(main_process_metrics, "spec_decode_draft_acceptance_rate"): @@ -1054,6 +1286,33 @@ def _record_speculative_decoding_accept_num_per_request(self, req_id, accept_num self.accept_token_num_per_head_per_request[req_id][i] += 1 self.accept_token_num_per_head[i] += 1 + def _put_abort_results(self, task): + now = time.time() + eos_token_ids = getattr(task, "eos_token_ids", [0]) + abort_metrics = copy.copy(task.metrics) + for field in ( + "arrival_time", + "inference_start_time", + "engine_recv_latest_token_time", + "engine_recv_first_token_time", + "request_start_time", + ): + if not getattr(abort_metrics, field): + setattr(abort_metrics, field, now) + result = RequestOutput( + request_id=task.request_id, + finished=True, + outputs=CompletionOutput( + index=0, + send_idx=self.tokens_counter.get(task.request_id), + token_ids=[eos_token_ids[0]], + ), + metrics=abort_metrics, + error_code=200, + error_msg="Aborted", + ) + self.cached_generated_tokens.put_results([result]) + def clear_data(self): if envs.ENABLE_V1_KVCACHE_SCHEDULER: self.resource_manager.clear_data() @@ -1076,6 +1335,7 @@ def clear_data(self): ), ) is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" + self._finalize_routing(task.request_id, task, result, is_prefill) self._recycle_resources(task.request_id, i, task, result, is_prefill) llm_logger.warning(f"clear data for task {task.request_id}") diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index 46e686c38a6..5a4666d46b5 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -14,19 +14,21 @@ # limitations under the License. """ +import asyncio import gc import glob import os import re import time from multiprocessing.shared_memory import SharedMemory -from typing import Any, Dict, List +from typing import Any, Dict, Iterable, List, Optional, Tuple import numpy as np import paddle import yaml from paddleformers.utils.log import logger +from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus @@ -52,10 +54,15 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): self.model_list = models self._capture_model_state() self.rdma_handle = None - if self.load_config.load_strategy == "rsync": - self.update_weights_by_rdma() + self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER + + if self.use_gdr_checkpoint_transfer: + self.update_weights_by_gdr() else: - self.update_parameters() + if self.load_config.load_strategy == "rsync": + self.update_weights_by_rdma() + else: + self.update_parameters() self.finalize_update() logger.info( @@ -64,20 +71,26 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): ) @paddle.no_grad() - def _capture_model_state(self): + def _capture_model_state(self, log_params: bool = True): """Capture and store initial model parameters state.""" + self.state_dict = {} for model in self.model_list: for name, param in model.state_dict().items(): - logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}") + if log_params: + logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}") self.state_dict[name] = param - def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False): + def update_weights_by_rdma( + self, + version: str = None, + verify_checksum: bool = False, + ): def valid_parameters(old_state_dict, new_state_dict): is_valid = True - for key in old_state_dict: - if key not in new_state_dict: + for key in new_state_dict: + if key not in old_state_dict: is_valid = False - logger.error(f"Invalid parameter: {key} not in new_state_dict") + logger.error(f"Invalid parameter: {key} not in old_state_dict") elif old_state_dict[key].shape != new_state_dict[key].shape: is_valid = False logger.error( @@ -92,14 +105,7 @@ def valid_parameters(old_state_dict, new_state_dict): ) return is_valid - bootstrap_load = version is None or version == "" - if bootstrap_load: - version = self.read_model_version_from_file() - if version is None or version == "": - raise Exception( - "rsync model version not set, please set it in 1) {model_version}/version.yaml " - "or 2) interface arguments 'version'" - ) + version, bootstrap_load = self._resolve_weight_update_version(version) logger.info( f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, " @@ -128,8 +134,8 @@ def valid_parameters(old_state_dict, new_state_dict): raise ValueError(error_msg) update_start = time.perf_counter() - for name, target_param in old_state_dict.items(): - new_param = new_state_dict[name] + for name, new_param in new_state_dict.items(): + target_param = old_state_dict[name] if bootstrap_load and not target_param._is_initialized(): new_param = new_param.cuda() new_param._share_buffer_to(target_param) @@ -151,6 +157,164 @@ def valid_parameters(old_state_dict, new_state_dict): "rank": self.local_rank, } + def update_weights_by_gdr( + self, version: str = None, verify_checksum: bool = False, restore_cleared_params: bool = False + ): + """Unified weight update via CheckpointTransfer (supports GDR and IPC backends).""" + config = dict(self.fd_config.load_config.rsync_config or {}) + is_ipc = self.load_config.load_strategy != "rsync" + + if is_ipc: + step_id = version or "0" + else: + version, _ = self._resolve_weight_update_version(version) + step_id = version + + logger.info( + f"START rank:{self.local_rank}/{self.nranks} update_weights_by_gdr, " + f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}" + ) + + from checkpoint_transfer.transfer import CheckpointTransfer + + transfer_config = self._build_ct_transfer_config(config) + logger.info(f"CheckpointTransfer config:{transfer_config}") + ct_handle = CheckpointTransfer(transfer_config) + + total_start = time.perf_counter() + asyncio.run(ct_handle.initialize()) + try: + weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle") + + if restore_cleared_params: + for name, target_param in self.state_dict.items(): + if not target_param._is_initialized(): + paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param) + logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}") + update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator) + finally: + asyncio.run(ct_handle.cleanup()) + self._capture_model_state(log_params=False) + total_cost = time.perf_counter() - total_start + logger.info( + f"END update_weights_by_gdr, cost {total_cost:.2f} seconds, " + f"weights:{update_count}, mtp_cached_weights:{mtp_cache_count}, " + f"step_id:{step_id}, local_rank:{self.local_rank}" + ) + return { + "update_cost": total_cost, + "total_cost": total_cost, + "version": step_id, + "rank": self.local_rank, + "update_count": update_count, + "mtp_cache_count": mtp_cache_count, + } + + def _build_ct_transfer_config(self, config: dict): + from dataclasses import fields + + from checkpoint_transfer.config import Phase1Backend, Role, TransferConfig + + transfer_config = dict(config) + if "device_name" in transfer_config and "device" not in transfer_config: + transfer_config["device"] = transfer_config.pop("device_name") + else: + transfer_config.pop("device_name", None) + + transfer_config["role"] = Role.INFERENCE + + if self.load_config.load_strategy == "rsync": + node_index = int(transfer_config.pop("index", 0)) + transfer_config["global_rank"] = node_index * self.nranks + self.local_rank + transfer_config["phase1_backend"] = Phase1Backend.GPU_DIRECT + transfer_config["group_size"] = int(transfer_config.get("group_size", self.nranks)) + else: + transfer_config.pop("index", None) + gpu_id = int(os.getenv("FLAGS_selected_gpus", "0")) + transfer_config["global_rank"] = gpu_id + transfer_config["phase1_backend"] = Phase1Backend.IPC + transfer_config["group_size"] = int(transfer_config.get("group_size", self.nranks)) + transfer_config["qsize"] = int(transfer_config.get("qsize", 2)) + + transfer_config_keys = {field.name for field in fields(TransferConfig)} + transfer_config = {key: value for key, value in transfer_config.items() if key in transfer_config_keys} + return TransferConfig(**transfer_config) + + def _resolve_weight_update_version(self, version: Optional[str]) -> Tuple[str, bool]: + bootstrap_load = version is None or version == "" + if bootstrap_load: + version = self.read_model_version_from_file() + if version is None or version == "": + raise Exception( + "rsync model version not set, please set it in 1) {model_version}/version.yaml " + "or 2) interface arguments 'version'" + ) + return version, bootstrap_load + + def _load_models_from_weight_iterator( + self, + weights_iterator: Iterable[Tuple[str, Any]], + ) -> Tuple[int, int]: + update_count = 0 + + if len(self.model_list) == 1: + + def count_weights(): + nonlocal update_count + for item in weights_iterator: + update_count += 1 + yield item + + self.model_list[0].load_weights(count_weights()) + return update_count, 0 + + mtp_models = self.model_list[1:] + config = self.fd_config.load_config.rsync_config or {} + mtp_chunk_size = max(1, int(config.get("gdr_mtp_chunk_size", 16))) + mtp_chunk: List[Tuple[str, Any]] = [] + mtp_cache_count = 0 + mtp_weight_tokens = ["mtp_", "mtp_block"] + for model in mtp_models: + model_config = getattr(getattr(model, "fd_config", None), "model_config", None) + start_layer = getattr(model, "mtp_start_layer_idx", None) + num_layers = getattr(model, "num_mtp_layers", None) + start_layer = start_layer if start_layer is not None else getattr(model_config, "start_layer_index", None) + num_layers = ( + num_layers if num_layers is not None else getattr(model_config, "num_nextn_predict_layers", None) + ) + if start_layer is None or num_layers is None: + continue + for layer_id in range(int(start_layer), int(start_layer) + int(num_layers)): + mtp_weight_tokens.append(f"layers.{layer_id}.") + mtp_weight_tokens.append(f".layers.{layer_id}.") + + def flush_mtp_chunk(): + nonlocal mtp_chunk + if not mtp_chunk: + return + for model in mtp_models: + model.load_weights(iter(mtp_chunk)) + mtp_chunk = [] + + def cache_mtp_weights(): + nonlocal update_count, mtp_cache_count + for item in weights_iterator: + name, _ = item + update_count += 1 + if any(token in name for token in mtp_weight_tokens): + mtp_chunk.append(item) + mtp_cache_count += 1 + yield item + if len(mtp_chunk) >= mtp_chunk_size: + flush_mtp_chunk() + + self.model_list[0].load_weights(cache_mtp_weights()) + flush_mtp_chunk() + if mtp_cache_count == 0: + raise ValueError("No MTP weights were cached from the GDR stream for auxiliary model loading.") + + return update_count, mtp_cache_count + def update_parameters(self, pid: int = 0, restart_process_group=False) -> None: """Core method to update model parameters based on strategy.""" start_time = time.perf_counter() @@ -348,6 +512,13 @@ def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None: if shutdown_process_group: paddle.distributed.shutdown_process_group(self.parallel_config.ep_group) if shutdown_process_group: + # ProcessGroupGloo has no shutdown(); remove it from paddle's registry + # before the global sweep to avoid AttributeError. + from paddle.distributed.collective import _get_group_map_by_name + + for name, pg in list(_get_group_map_by_name().items()): + if pg.process_group is not None and not hasattr(pg.process_group, "shutdown"): + _get_group_map_by_name().pop(name, None) paddle.distributed.shutdown_process_group() self._update_shared_status(pid, ModelWeightsStatus.CLEARED) @@ -407,7 +578,7 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T if src.shape != dst.shape: raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}") - def finalize_update(self, pid: int = 0): + def finalize_update(self, pid: Optional[int] = None): """Finalize update process with verification.""" self._verify_parameters("update") @@ -472,8 +643,10 @@ def _log_memory(self, context: str): f"current_reserved: {curr_reserved:.2f}GB" ) - def _update_shared_status(self, pid: int, status: int) -> None: + def _update_shared_status(self, pid: Optional[int], status: int) -> None: """Update shared memory status flag for inter-process communication.""" + if pid is None: + pid = self.parallel_config.local_engine_worker_queue_port array = np.zeros([1], dtype=np.int32) shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}") value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf) diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index cade1355088..59a7822c3ff 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -14,8 +14,9 @@ # limitations under the License. """ -from typing import Any, Dict, Optional +from typing import Dict, Optional, Union +from fastdeploy.utils import parse_quantization from fastdeploy.worker.worker_process import initialize_fd_config @@ -54,7 +55,7 @@ def __init__( expert_parallel_size: int = 1, enable_expert_parallel: bool = False, ori_vocab_size: int = None, - quantization: Optional[Dict[str, Any]] = None, + quantization: Optional[Union[Dict, str]] = None, guided_decoding_backend: str = "off", disable_any_whitespace: bool = True, enable_logprob: bool = False, @@ -68,6 +69,7 @@ def __init__( routing_replay_config: str = None, load_choices: str = "default_v1", lm_head_fp32: bool = False, + moe_gate_fp32: bool = True, ): # Required parameters self.model = model_name_or_path @@ -107,7 +109,7 @@ def __init__( self.enable_expert_parallel = enable_expert_parallel self.data_parallel_size = data_parallel_size self.ori_vocab_size = ori_vocab_size - self.quantization = quantization + self.quantization = parse_quantization(quantization) self.guided_decoding_backend = guided_decoding_backend self.disable_any_whitespace = disable_any_whitespace self.enable_logprob = enable_logprob @@ -121,6 +123,7 @@ def __init__( self.routing_replay_config = routing_replay_config self.load_choices = load_choices self.lm_head_fp32 = lm_head_fp32 + self.moe_gate_fp32 = moe_gate_fp32 def __str__(self): return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items()) diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index 960a64e7f58..1e1adf5fd9b 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -18,7 +18,7 @@ import aiohttp import uvicorn from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse +from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastdeploy.router.utils import ( InstanceInfo, @@ -29,6 +29,7 @@ from fastdeploy.utils import router_logger as logger app = FastAPI() +_background_tasks = set() @dataclass @@ -49,6 +50,14 @@ class RouterArgs: """ Request timeout in seconds """ + preempt_retry_count: int = 3 + """ + Max retry count when decode instance preempts a request in splitwise mode. + """ + preempt_retry_exclude_decode: bool = False + """ + Whether to exclude the previously used decode instance when retrying after preemption. + """ @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -76,6 +85,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=RouterArgs.request_timeout_secs, help="Request timeout in seconds", ) + parser.add_argument( + "--preempt-retry-count", + type=int, + default=RouterArgs.preempt_retry_count, + help="Max retry count when decode instance preempts a request in splitwise mode.", + ) + parser.add_argument( + "--preempt-retry-exclude-decode", + action="store_true", + default=RouterArgs.preempt_retry_exclude_decode, + help="Whether to exclude the previously used decode instance when retrying after preemption.", + ) return parser @@ -91,6 +112,8 @@ def __init__(self, args): self.port = args.port self.splitwise = args.splitwise self.timeout = args.request_timeout_secs + self.preempt_retry_count = args.preempt_retry_count + self.preempt_retry_exclude_decode = args.preempt_retry_exclude_decode self.mixed_servers = [] self.prefill_servers = [] @@ -152,16 +175,21 @@ async def get_decode_instances(self, version: Optional[str] = None) -> List[Dict instances = [inst for inst in instances if inst.version == version] return [inst.to_dict() for inst in instances] - async def select_pd(self): - """Select one prefill and one decode server""" + async def select_pd(self, exclude_decode=None): + """Select one prefill and one decode server, optionally excluding a decode instance.""" async with self.lock: if not self.prefill_servers: raise RuntimeError(f"No prefill servers available (decode={len(self.decode_servers)})") if not self.decode_servers: raise RuntimeError(f"No decode servers available (prefill={len(self.prefill_servers)})") pidx = random.randint(0, len(self.prefill_servers) - 1) - didx = random.randint(0, len(self.decode_servers) - 1) - return self.prefill_servers[pidx], self.decode_servers[didx] + available_decode = ( + [d for d in self.decode_servers if d is not exclude_decode] if exclude_decode else self.decode_servers + ) + if not available_decode: + available_decode = self.decode_servers + didx = random.randint(0, len(available_decode) - 1) + return self.prefill_servers[pidx], available_decode[didx] async def select_mixed(self): """Select one mixed server""" @@ -191,57 +219,108 @@ async def handle_mixed_request(self, request_data: dict, endpoint_name: str): async def handle_splitwise_request(self, request_data: dict, endpoint_name: str): logger.debug(f"Received request: {request_data}") - prefill_server, decode_server = await self.select_pd() - logger.debug(f"Selected prefill server: {prefill_server}") - logger.debug(f"Selected decode server: {decode_server}") - - if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: - raise HTTPException( - status_code=400, - detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1", + last_decode_server = None + # Preserve client request_id on first attempt; append retry suffix on subsequent attempts + base_request_id = request_data.get("request_id") or str(uuid4()) + max_attempts = self.preempt_retry_count + 1 + completion_token_ids = [] + + for attempt in range(max_attempts): + prefill_server, decode_server = await self.select_pd( + exclude_decode=last_decode_server if self.preempt_retry_exclude_decode else None ) + logger.debug(f"Selected prefill server: {prefill_server}, decode server: {decode_server}") - # TODO: unify the disaggregate_info in server and remove redundancy params - is_same_node = prefill_server.host_ip == decode_server.host_ip - is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol - is_same_tp_size = prefill_server.tp_size == decode_server.tp_size - use_ipc = is_same_node and is_support_ipc and is_same_tp_size - - disaggregate_info = { - "prefill_ip": prefill_server.host_ip, - "decode_ip": decode_server.host_ip, - "prefill_connector_port": prefill_server.connector_port, - "decode_connector_port": decode_server.connector_port, - "decode_device_ids": decode_server.device_ids, - "decode_rdma_ports": decode_server.rdma_ports, - "transfer_protocol": "ipc" if use_ipc else "rdma", - "decode_tp_size": decode_server.tp_size, - } - - modified_request = request_data.copy() - modified_request["disaggregate_info"] = disaggregate_info - if "request_id" not in modified_request: - modified_request["request_id"] = str(uuid4()) + if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: + raise HTTPException( + status_code=400, + detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1", + ) - logger.debug(f"Modified request: {modified_request}") + # TODO: unify the disaggregate_info in server and remove redundancy params + is_same_node = prefill_server.host_ip == decode_server.host_ip + is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol + is_same_tp_size = prefill_server.tp_size == decode_server.tp_size + use_ipc = is_same_node and is_support_ipc and is_same_tp_size + + disaggregate_info = { + "prefill_ip": prefill_server.host_ip, + "decode_ip": decode_server.host_ip, + "prefill_connector_port": prefill_server.connector_port, + "decode_connector_port": decode_server.connector_port, + "decode_device_ids": decode_server.device_ids, + "decode_rdma_ports": decode_server.rdma_ports, + "transfer_protocol": "ipc" if use_ipc else "rdma", + "decode_tp_size": decode_server.tp_size, + } + + modified_request = request_data.copy() + modified_request["disaggregate_info"] = disaggregate_info + if completion_token_ids: + modified_request["completion_token_ids"] = completion_token_ids + if attempt == 0: + modified_request["request_id"] = base_request_id + else: + modified_request["request_id"] = f"{base_request_id}-retry{attempt}" - if request_data.get("stream", False): - return await self._generate_stream( - modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name - ) - else: - return await self._generate( - modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name - ) + logger.debug(f"Modified request: {modified_request}") - async def _generate( + if request_data.get("stream", False): + return await self._generate_stream( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + else: + ret_json, status_code = await self._do_generate( + modified_request, [prefill_server.url(), decode_server.url()], endpoint=endpoint_name + ) + logger.debug(f"Get response of req {modified_request['request_id']}: {ret_json}") + + if self._is_need_reschedule(ret_json): + last_decode_server = decode_server + choices = ret_json.get("choices", []) + if choices: + completion_token_ids.extend(choices[0].get("message", {}).get("completion_token_ids", [])) + + logger.warning( + f"Preemption detected on attempt {attempt+1}/{max_attempts}, " + f"decode={decode_server.url()}, req_id {modified_request['request_id']}," + f"retrying with new PD instances..." + ) + else: + break + + logger.debug(f"Return response of req_id {base_request_id}: {ret_json}") + return ORJSONResponse(content=ret_json, status_code=status_code) + + def _is_need_reschedule(self, ret_json: dict) -> bool: + # ChatCompletionResponse format: choices[0].finish_reason == "pd_reschedule" + choices = ret_json.get("choices", []) + if choices: + finish_reason = choices[0].get("finish_reason", "") + if finish_reason == "pd_reschedule": + logger.debug(f"PD reschedule request, ret_json: {ret_json}") + return True + # ErrorResponse format compatibility + error = ret_json.get("error", {}) + if isinstance(error, dict) and "PD Error" in str(error.get("message", "")): + return True + return False + + async def _do_generate( self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" - ) -> ORJSONResponse: + ) -> tuple: + """Send requests and return (ret_json, status_code).""" async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=self.timeout)) as session: tasks = [session.post(f"{url}/{endpoint}", json=modified_request) for url in urls] results = await asyncio.gather(*tasks) ret_json = await results[return_result_url_index].json() - return ORJSONResponse(content=ret_json, status_code=results[return_result_url_index].status) + return ret_json, results[return_result_url_index].status + + async def _generate( + self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" + ) -> ORJSONResponse: + ret_json, status_code = await self._do_generate(modified_request, urls, return_result_url_index, endpoint) + return ORJSONResponse(content=ret_json, status_code=status_code) async def _generate_stream( self, modified_request, urls, return_result_url_index=-1, endpoint="v1/chat/completions" @@ -510,39 +589,15 @@ async def abort_requests(request: Request): decode_servers = app.state.router.decode_servers all_servers = prefill_servers + decode_servers - async with aiohttp.ClientSession() as session: - tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] - responses = await asyncio.gather(*tasks, return_exceptions=True) - - # Aggregate results from Node D only - all_aborted = [] - all_not_found = [] - errors = [] - decode_start = len(prefill_servers) - for i, (server, resp) in enumerate(zip(all_servers, responses)): - if i < decode_start: - continue - if isinstance(resp, Exception): - errors.append({"server": server.url(), "error": str(resp)}) - elif resp.status == 200: - data = await resp.json() - result = data.get("result") or {} - all_aborted.extend(result.get("aborted", [])) - all_not_found.extend(result.get("not_found", [])) - else: - errors.append({"server": server.url(), "status": resp.status}) - - return JSONResponse( - content={ - "request_id": f"router-{uuid4()}", - "status": "success" if not errors else "error", - "error_message": None if not errors else str(errors), - "result": { - "aborted": all_aborted, - "not_found": list(set(all_not_found)), - }, - } - ) + async def _forward_abort(): + async with aiohttp.ClientSession() as session: + tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers] + await asyncio.gather(*tasks, return_exceptions=True) + + task = asyncio.create_task(_forward_abort()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + return Response(status_code=200) def launch_router(router_args: RouterArgs): diff --git a/fastdeploy/scheduler/config.py b/fastdeploy/scheduler/config.py index 1422b2635f3..90016479063 100644 --- a/fastdeploy/scheduler/config.py +++ b/fastdeploy/scheduler/config.py @@ -273,6 +273,7 @@ def __init__(self, args): self.max_num_seqs = 34 self.splitwise_role = "mixed" self.enable_overlap_schedule = False + self.enable_moe_scores_elementwise_fuse = False self.config = None for key, value in args.items(): diff --git a/fastdeploy/scheduler/dp_scheduler.py b/fastdeploy/scheduler/dp_scheduler.py index 2339a077c96..f5b03eba30f 100644 --- a/fastdeploy/scheduler/dp_scheduler.py +++ b/fastdeploy/scheduler/dp_scheduler.py @@ -23,7 +23,7 @@ from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledResponse from fastdeploy.scheduler.local_scheduler import LocalScheduler -from fastdeploy.utils import get_logger +from fastdeploy.utils import envs, get_logger class DPLocalScheduler(LocalScheduler): @@ -131,19 +131,52 @@ def get_requests( Returns: List of Request objects ready for processing """ - # DP scheduler is used in V1, there is no need to manage request fetching in the scheduler, resource_manager_v1 will do that. + if available_blocks <= reserved_output_blocks or batch < 1: + self.scheduler_logger.debug( + f"Scheduler's resource are insufficient: available_blocks={available_blocks} " + f"reserved_output_blocks={reserved_output_blocks} batch={batch} " + f"max_num_batched_tokens={max_num_batched_tokens}" + ) + return [] + required_total_blocks = 0 + current_prefill_tokens = 0 + start_batch_time = time.time() requests: List[Request] = [] with self.requests_not_empty: - batch_ids = self.requests_not_empty.wait_for( - lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + 1], - 0.005, - ) - if batch_ids: - for request_id in batch_ids: - request = self.requests[request_id] - requests.append(request.raw) - self.ids_read_cursor += 1 + while True: + batch_ids = self.requests_not_empty.wait_for( + lambda: self.ids[self.ids_read_cursor : self.ids_read_cursor + batch], + 0.005, + ) + if batch_ids: + for request_id in batch_ids: + request = self.requests[request_id] + required_input_blocks = self.calc_required_blocks(request.prompt_tokens_ids_len, block_size) + current_prefill_tokens += request.prompt_tokens_ids_len + required_total_blocks += required_input_blocks + reserved_output_blocks + if required_total_blocks > available_blocks: + break + + requests.append(request.raw) + self.ids_read_cursor += 1 + start_batch_time = time.time() + if current_prefill_tokens > max_num_batched_tokens: + break + if len(requests) >= batch: + break + if ( + (current_prefill_tokens > max_num_batched_tokens) + or (len(requests) >= batch) + or (time.time() - start_batch_time > envs.FD_EP_BATCHED_TOKEN_TIMEOUT) + ): + break + + if batch_ids: + if len(batch_ids) > 0 and len(requests) == 0: + self.scheduler_logger.debug( + f"Scheduler has put all just-pulled request into the queue: {len(batch_ids)}" + ) if len(requests) > 0: self.scheduler_logger.info( diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index fc4a64686b5..8fca9a4690d 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -129,8 +129,11 @@ def _recycle(self, request_id: Optional[str] = None): if request_id is not None: self.requests.pop(request_id, None) self.responses.pop(request_id, None) - self.ids.pop(self.ids.index(request_id)) - self.ids_read_cursor -= 1 + idx = self.ids.index(request_id) + self.ids.pop(idx) + if idx < self.ids_read_cursor: + self.ids_read_cursor -= 1 + scheduler_logger.debug(f"request_id : {request_id} has been recycled") return if self.max_size <= 0: diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index fa50eae462a..08553411188 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -71,7 +71,7 @@ def __init__(self, fd_config: "FDConfig"): self.max_ngram_size = self.speculative_config.max_ngram_size self.min_ngram_size = self.speculative_config.min_ngram_size - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime spec_logger.info(f"Speculate config: {self.speculative_config}") @@ -118,7 +118,7 @@ def prepare_dummy_speculative_drafts( stop = share_inputs["stop_flags"][0].item() if not stop: - share_inputs["draft_tokens"][:batch_size, :max_fake_drafts] = 5 + share_inputs["draft_tokens"][:batch_size, : max_fake_drafts + 1] = 5 share_inputs["seq_lens_this_time"][:batch_size] = max_fake_drafts + 1 else: share_inputs["seq_lens_this_time"][:batch_size] = 0 diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 4ec57e93594..41b0da93819 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -49,7 +49,10 @@ share_external_data, update_attn_mask_offsets, ) + + # temporary solution from fastdeploy.model_executor.xpu_pre_and_post_process import ( + async_set_value, xpu_pre_process, xpu_process_output, ) @@ -62,7 +65,6 @@ eagle_get_self_hidden_states, eagle_gather_hidden_states, hybrid_mtp_ngram, - mtp_save_first_token, mtp_step_paddle, share_external_data, speculate_get_logits, @@ -103,7 +105,7 @@ def __init__( self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id - self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text" + self.use_attn_mask_offset = self.enable_mm self._update_mtp_config(main_model) self._load_model() @@ -243,10 +245,12 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. - # 2. If no need to profile, create kv cache if cache managers do not exist. + # 2. If no need to profile, create kv cache unless kvcache_storage_backend or + # p/d disaggregation is enabled. Note: CPU cache (num_cpu_blocks > 0) does NOT + # prevent GPU runner from creating GPU cache tensors; cache transfer manager + # handles CPU<->GPU swap on top of the GPU tensors created here. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -403,6 +407,19 @@ def _initialize_attn_backend( self.target_model_inputs["kv_num_blocks_x_cpu"] ).cpu() + # Decode attention split ops buffers + if ( + "decode_block_indices" in self.target_model_inputs + and self.target_model_inputs["decode_block_indices"] is not None + ): + self.model_inputs["decode_block_indices"] = self.target_model_inputs["decode_block_indices"] + + self.model_inputs["decode_num_blocks"] = self.target_model_inputs["decode_num_blocks"] + self.model_inputs["decode_chunk_size"] = self.target_model_inputs["decode_chunk_size"] + self.model_inputs["decode_tmp_workspace"] = self.target_model_inputs["decode_tmp_workspace"] + self.model_inputs["decode_tmp_m"] = self.target_model_inputs["decode_tmp_m"] + self.model_inputs["decode_tmp_d"] = self.target_model_inputs["decode_tmp_d"] + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( @@ -424,8 +441,7 @@ def clear_mtp_cache(self, profile=False): Clear allocated cacheKV """ create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) if not create_cache_tensor: @@ -483,28 +499,32 @@ def insert_tasks_v1( input_ids = request.prompt_token_ids + request.output_token_ids self.model_inputs["input_ids_len"][idx] = length - 1 - self.model_inputs["pre_ids"][idx : idx + 1] = -1 + async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1) self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ idx : idx + 1, 1:length ] - self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ - "input_ids" - ][idx : idx + 1, 1:length].cpu() + # TODO: use token_all_ids replace with input_ids_cpu + if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs: + self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[ + "input_ids" + ][idx : idx + 1, 1:length].cpu() encoder_block_num = len(request.block_tables) - self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 - self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( - request.block_tables, dtype="int32" + async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value( + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) - self.model_inputs["stop_flags"][idx : idx + 1] = False - self.model_inputs["batch_drop"][idx : idx + 1] = False - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length + async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], False) + async_set_value(self.model_inputs["batch_drop"][idx : idx + 1], False) + + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True - self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length - self.model_inputs["step_idx"][idx : idx + 1] = ( - len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) + async_set_value( + self.model_inputs["step_idx"][idx : idx + 1], + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) if self.use_attn_mask_offset: inputs = request.multimodal_inputs @@ -522,18 +542,19 @@ def insert_tasks_v1( if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generates first token - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False - self.model_inputs["recompute_token_num"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length + 1 + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length + 1) # NOTE(liuzichang): # extra 1 : P-D split need rollback one step - self.model_inputs["mask_rollback"][idx : idx + 1] = 1 + + async_set_value(self.model_inputs["recompute_token_num"][idx : idx + 1], 0) + async_set_value(self.model_inputs["mask_rollback"][idx : idx + 1], 1) # has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task encoder_block_num = len(request.block_tables) - self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables @@ -542,16 +563,13 @@ def insert_tasks_v1( self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) - # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode - # has_decode_task = True - # continue else: - self.model_inputs["block_tables"][idx : idx + 1, :] = -1 - self.model_inputs["stop_flags"][idx : idx + 1] = True - self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 - self.model_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value(self.model_inputs["stop_flags"][idx : idx + 1], True) + async_set_value(self.model_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) + async_set_value(self.model_inputs["seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.model_inputs["seq_lens_encoder"][idx : idx + 1], 0) + async_set_value(self.model_inputs["is_block_step"][idx : idx + 1], False) continue # TODO(liuzichang): Solve splitewise-p bug to restore @@ -673,6 +691,15 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.model_inputs: + self.forward_meta.decode_block_indices = self.model_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.model_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.model_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.model_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.model_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.model_inputs["decode_tmp_d"] + # Initialzie attention meta data for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) @@ -835,23 +862,26 @@ def _post_process(self, sampled_token_ids): ) if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0: - skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) - recover_model_output_map = recover_batch_index_for_output( - self.model_inputs, - self.model_inputs.index_to_batch_id, - self.model_inputs.enable_pd_reorder, - ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], - ) - mtp_save_first_token( - recover_model_output_map["base_model_draft_tokens"], - self.model_inputs["not_need_stop"], - recover_model_output_map["seq_lens_decoder"], - recover_model_output_map["prompt_lens"], - recover_model_output_map["step_idx"], - self.local_rank, - self.parallel_config.use_ep, - skip_save, - ) + if current_platform.is_xpu(): + # Note(wangyanpeng): mtp_save_first_token for GPU platforms has been moved to model_runner. + # Only XPU platform is retained here. + skip_save = bool(int(envs.ENABLE_V1_KVCACHE_SCHEDULER)) + recover_model_output_map = recover_batch_index_for_output( + self.model_inputs, + self.model_inputs.index_to_batch_id, + self.model_inputs.enable_pd_reorder, + ["base_model_draft_tokens", "seq_lens_decoder", "prompt_lens", "step_idx"], + ) + mtp_save_first_token( + recover_model_output_map["base_model_draft_tokens"], + self.model_inputs["not_need_stop"], + recover_model_output_map["seq_lens_decoder"], + recover_model_output_map["prompt_lens"], + recover_model_output_map["step_idx"], + self.local_rank, + self.parallel_config.use_ep, + skip_save, + ) # Ensure only save first token once. paddle.assign( paddle.where( @@ -880,7 +910,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F token_num_cpu = self.model_inputs["seq_lens_this_time"].numpy().sum().item() else: if substep == 0: - token_num_cpu = real_bsz * (self.max_draft_token_num + 1) + token_num_cpu = self.model_inputs["target_hidden_states"].shape[0] else: token_num_cpu = real_bsz if token_num_cpu > 0: diff --git a/fastdeploy/spec_decode/suffix.py b/fastdeploy/spec_decode/suffix.py index f4d1495524c..e11c2255a3e 100644 --- a/fastdeploy/spec_decode/suffix.py +++ b/fastdeploy/spec_decode/suffix.py @@ -43,7 +43,7 @@ def __init__(self, fd_config: "FDConfig"): if SuffixDecodingCache is None: raise ImportError( - "arctic_inference.suffix_decoding is not available. Please install arctic-inference package." + "arctic_inference.suffix_decoding is not available. Please install via `pip install arctic-inference==0.1.2`." ) # Initialize SuffixDecodingCache diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py index e64e468b186..5c2f793fdbf 100644 --- a/fastdeploy/splitwise/internal_adapter_utils.py +++ b/fastdeploy/splitwise/internal_adapter_utils.py @@ -53,9 +53,6 @@ def _get_current_server_info(self): available_batch_size = min(self.cfg.max_prefill_batch, self.engine.resource_manager.available_batch()) available_block_num = self.engine.resource_manager.available_block_num() - unhandled_request_num = self.engine.scheduler.get_unhandled_request_num() - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - unhandled_request_num = max(unhandled_request_num, len(self.engine.resource_manager.waiting)) server_info = { "splitwise_role": self.cfg.scheduler_config.splitwise_role, "block_size": int(self.cfg.cache_config.block_size), @@ -65,7 +62,7 @@ def _get_current_server_info(self): "available_resource": float(1.0 * available_block_num / self.cfg.cache_config.total_block_num), "max_batch_size": int(available_batch_size), "max_input_token_num": self.cfg.model_config.max_model_len, - "unhandled_request_num": unhandled_request_num, + "unhandled_request_num": self.engine.scheduler.get_unhandled_request_num(), "available_batch": int(self.engine.resource_manager.available_batch()), } return server_info diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 7200c99ed9c..acbe71411da 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -15,10 +15,11 @@ """ import pickle +import threading import time import traceback from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List +from typing import Dict, List, Tuple import zmq @@ -58,6 +59,8 @@ def __init__(self, cfg, worker_queue, resource_manager): if self.cfg.scheduler_config.splitwise_role != "mixed": self.zmq_ctx = zmq.Context() self.push_sockets: Dict[str, zmq.Socket] = {} + self._push_socket_locks: Dict[str, threading.Lock] = {} + self._push_sockets_meta_lock = threading.Lock() self.pull_socket = None self.io_executor = ThreadPoolExecutor(max_workers=4) self._init_network() @@ -105,13 +108,21 @@ def start_receiver(self): self.logger.error(f"start_receiver: Receiver error: {e}, {str(traceback.format_exc())}") time.sleep(1) - def _get_push_socket(self, addr): - """获取或创建 DEALER socket""" + def _get_push_socket(self, addr) -> Tuple[zmq.Socket, threading.Lock]: + """ + 获取或创建 DEALER socket 及其发送锁。 + + Returns: + Tuple[zmq.Socket, threading.Lock]: 目标地址对应的 socket 和保护 multipart 发送的锁。 + """ - if addr in self.push_sockets: - sock = self.push_sockets[addr] - if not sock.closed: - return sock + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + sock = self.push_sockets[addr] + if not sock.closed: + return sock, self._push_socket_locks[addr] + del self.push_sockets[addr] + self._push_socket_locks.pop(addr, None) try: self.logger.info(f"_get_push_socket: Establishing new connection to {addr}") @@ -129,8 +140,18 @@ def _get_push_socket(self, addr): sock.connect(f"tcp://{addr}") - self.push_sockets[addr] = sock - return sock + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + existing_sock = self.push_sockets[addr] + if not existing_sock.closed: + sock.close() + return existing_sock, self._push_socket_locks[addr] + del self.push_sockets[addr] + self._push_socket_locks.pop(addr, None) + + self.push_sockets[addr] = sock + self._push_socket_locks[addr] = threading.Lock() + return sock, self._push_socket_locks[addr] except zmq.ZMQError as e: self.logger.error(f"_get_push_socket: Connection to {addr} failed: {e}") @@ -144,8 +165,11 @@ def _send_message(self, addr, msg_type: str, payload): message = self._serialize_message(msg_type, payload) try: self.logger.info(f"_send_message: msg_type={msg_type} addr={addr}") - sock = self._get_push_socket(addr) - sock.send_multipart(message) + sock, lock = self._get_push_socket(addr) + with lock: + if sock.closed: + raise ConnectionError(f"Connection to {addr} is closed") + sock.send_multipart(message) self.logger.info(f"Sent {msg_type} to {addr}") @@ -164,9 +188,19 @@ def _close_connection(self, addr): """ Close the connection to the specified address. """ - if addr in self.push_sockets: - self.push_sockets[addr].close() - del self.push_sockets[addr] + sock = None + lock = None + with self._push_sockets_meta_lock: + if addr in self.push_sockets: + sock = self.push_sockets.pop(addr) + lock = self._push_socket_locks.pop(addr, None) + + if sock is not None: + if lock is not None: + with lock: + sock.close() + else: + sock.close() def send_splitwise_tasks(self, tasks: List[Request], current_id): """ @@ -202,6 +236,27 @@ def send_first_token(self, prefill_msg, tasks_list): ) self._send_message(addr, "decode", tasks_list) + def send_drop_signal(self, request_id: str, disaggregate_info: dict): + """ + Notify the decode side that this prefill request has been dropped + (e.g. paused gate rejected it on P). The decode side should recycle + its scheduler entry for this request_id, otherwise it would sit + there forever as a ghost and pause/abort drain would hang. + """ + if not disaggregate_info: + return + decode_ip = disaggregate_info.get("decode_ip") + decode_port = disaggregate_info.get("decode_connector_port") + if not decode_ip or not decode_port: + self.logger.warning( + f"send_drop_signal: missing decode_ip/decode_connector_port in " + f"disaggregate_info for {request_id}; skip" + ) + return + addr = f"{decode_ip}:{decode_port}" + self.logger.info(f"send_drop_signal: addr={addr}, request_id={request_id}") + self._send_message(addr, "drop", {"request_id": request_id}) + def check_decode_allocated(self, task): """Check whether the requests have been allocated resources in decode.""" self.logger.debug(f"check_decode_allocated: {task.request_id}") @@ -212,7 +267,7 @@ def check_decode_allocated(self, task): return True, "" while self.current_request_ids[task.request_id] == "init": - time.sleep(0.001) + time.sleep(0.005) if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS: del self.current_request_ids[task.request_id] return False, "prefill waits for decode resource timeout" @@ -277,23 +332,15 @@ def send_cache_info_to_prefill(self, tasks: List[Request]): "request_id": tasks[i].request_id, "error_msg": tasks[i].get("error_msg"), } - if ( - envs.ENABLE_V1_KVCACHE_SCHEDULER - and tasks[i].request_id in self.resource_manager.waiting_abort_req_id_set - ): - addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" - if addr not in cache_info: - cache_info[addr] = [] - cache_info[addr].append(info) else: - addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" info = { "request_id": tasks[i].request_id, "dest_block_ids": dsg_info["block_tables"], } - if addr not in cache_info: - cache_info[addr] = [] - cache_info[addr].append(info) + addr = f"{dsg_info['prefill_ip']}:" + f"{dsg_info['prefill_connector_port']}" + if addr not in cache_info: + cache_info[addr] = [] + cache_info[addr].append(info) self.logger.debug(f"send cache info to prefill, {cache_info}") if len(cache_info): @@ -356,6 +403,8 @@ def _process_message(self, frames: List[bytes]): self._handle_prefill(payload) elif msg_type == "decode": self._handle_decode(payload) + elif msg_type == "drop": + self._handle_drop(payload) elif msg_type == "cache_sync": for task in payload: self.logger.info(f"_process_message: cache_sync task: {task}") @@ -363,7 +412,7 @@ def _process_message(self, frames: List[bytes]): self.current_request_ids[task["request_id"]] = current_status if self.enable_decode_cache_task: del self.current_request_ids[task["request_id"]] - if current_status == "finished": + if current_status == "finished" and not envs.FD_PD_TRANSFER_VIA_STORAGE: self.engine_worker_queue.put_cache_info(payload) except Exception as e: @@ -386,3 +435,15 @@ def _handle_decode(self, payload): for task in payload: tasks.append(RequestOutput.from_dict(task)) self.engine_worker_queue.put_disaggregated_tasks(("decode", tasks)) + + def _handle_drop(self, payload): + """ + Handle drop signal from prefill: forward to engine worker queue so the + decode engine main loop can recycle the corresponding scheduler entry. + """ + request_id = payload.get("request_id") if isinstance(payload, dict) else None + if not request_id: + self.logger.warning(f"_handle_drop: invalid payload {payload}") + return + self.logger.info(f"_handle_drop: request_id={request_id}") + self.engine_worker_queue.put_disaggregated_tasks(("decode_drop", [request_id])) diff --git a/fastdeploy/trace/constants.py b/fastdeploy/trace/constants.py index eaf54d68085..ff481af248a 100644 --- a/fastdeploy/trace/constants.py +++ b/fastdeploy/trace/constants.py @@ -26,17 +26,35 @@ class LoggingEventName(Enum): REQUEST_QUEUE_START = "REQUEST_QUEUE_START" REQUEST_QUEUE_END = "REQUEST_QUEUE_END" RESOURCE_ALLOCATE_START = "RESOURCE_ALLOCATE_START" + PREPARE_PREFIX_CACHE_START = "PREPARE_PREFIX_CACHE_START" + PREPARE_PREFIX_CACHE_END = "PREPARE_PREFIX_CACHE_END" RESOURCE_ALLOCATE_END = "RESOURCE_ALLOCATE_END" REQUEST_SCHEDULE_END = "REQUEST_SCHEDULE_END" INFERENCE_START = "INFERENCE_START" FIRST_TOKEN_GENERATED = "FIRST_TOKEN_GENERATED" DECODE_START = "DECODE_START" INFERENCE_END = "INFERENCE_END" + WRITE_CACHE_TO_STORAGE_START = "WRITE_CACHE_TO_STORAGE_START" + WRITE_CACHE_TO_STORAGE_END = "WRITE_CACHE_TO_STORAGE_END" POSTPROCESSING_START = "POSTPROCESSING_START" POSTPROCESSING_END = "POSTPROCESSING_END" PREEMPTED = "PREEMPTED" RESCHEDULED_INFERENCE_START = "RESCHEDULED_INFERENCE_START" + # For Prefill Instance + ASK_DECODE_RESOURCE_START = "ASK_DECODE_RESOURCE_START" + ASK_DECODE_RESOURCE_END = "ASK_DECODE_RESOURCE_END" + CHECK_CACHE_TRANSFER_START = "CHECK_CACHE_TRANSFER_START" + CHECK_CACHE_TRANSFER_END = "CHECK_CACHE_TRANSFER_END" + PREFILL_INFERENCE_END = "PREFILL_INFERENCE_END" + + # For Decode Instance + DECODE_PROCESS_PREALLOCATE_REQUEST_START = "DECODE_PROCESS_PREALLOCATE_REQUEST_START" + DECODE_PROCESS_PREALLOCAT_REQUEST_END = "DECODE_PROCESS_PREALLOCAT_REQUEST_END" + DECODE_PROCESS_PREFILLED_REQUEST_START = "DECODE_PROCESS_PREFILLED_REQUEST_START" + DECODE_PROCESS_PREFILLED_REQUEST_END = "DECODE_PROCESS_PREFILLED_REQUEST_END" + DECODE_INFERENCE_END = "DECODE_INFERENCE_END" + class StageName(Enum): """ @@ -57,6 +75,8 @@ class StageName(Enum): LoggingEventName.REQUEST_QUEUE_START: StageName.SCHEDULE, LoggingEventName.REQUEST_QUEUE_END: StageName.SCHEDULE, LoggingEventName.RESOURCE_ALLOCATE_START: StageName.SCHEDULE, + LoggingEventName.PREPARE_PREFIX_CACHE_START: StageName.SCHEDULE, + LoggingEventName.PREPARE_PREFIX_CACHE_END: StageName.SCHEDULE, LoggingEventName.RESOURCE_ALLOCATE_END: StageName.SCHEDULE, LoggingEventName.REQUEST_SCHEDULE_END: StageName.SCHEDULE, LoggingEventName.INFERENCE_START: StageName.PREFILL, @@ -65,6 +85,18 @@ class StageName(Enum): LoggingEventName.PREEMPTED: StageName.DECODE, LoggingEventName.RESCHEDULED_INFERENCE_START: StageName.DECODE, LoggingEventName.INFERENCE_END: StageName.DECODE, + LoggingEventName.WRITE_CACHE_TO_STORAGE_START: StageName.POSTPROCESSING, + LoggingEventName.WRITE_CACHE_TO_STORAGE_END: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_START: StageName.POSTPROCESSING, LoggingEventName.POSTPROCESSING_END: StageName.POSTPROCESSING, + LoggingEventName.ASK_DECODE_RESOURCE_START: StageName.SCHEDULE, + LoggingEventName.ASK_DECODE_RESOURCE_END: StageName.SCHEDULE, + LoggingEventName.CHECK_CACHE_TRANSFER_START: StageName.POSTPROCESSING, + LoggingEventName.CHECK_CACHE_TRANSFER_END: StageName.POSTPROCESSING, + LoggingEventName.PREFILL_INFERENCE_END: StageName.PREFILL, + LoggingEventName.DECODE_PROCESS_PREALLOCATE_REQUEST_START: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREALLOCAT_REQUEST_END: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_START: StageName.DECODE, + LoggingEventName.DECODE_PROCESS_PREFILLED_REQUEST_END: StageName.DECODE, + LoggingEventName.DECODE_INFERENCE_END: StageName.DECODE, } diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 0a591dc2777..965fca6d96c 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -38,7 +38,7 @@ from importlib.metadata import PackageNotFoundError, distribution from logging.handlers import BaseRotatingHandler from pathlib import Path -from typing import Any, Literal, TypeVar, Union +from typing import Any, Dict, Literal, TypeVar, Union import numpy as np import paddle @@ -626,6 +626,12 @@ def is_port_available(host, port): import errno import socket + # If FD_ENGINE_TASK_QUEUE_WITH_SHM is enabled, then check the file socket is available + if envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + socket_path = f"/dev/shm/fd_task_queue_{port}.sock" + if not is_file_socket_available(socket_path): + return False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -637,6 +643,35 @@ def is_port_available(host, port): return True +def is_file_socket_available(socket_path): + """ + Check the Unix domain socket (file socket) is available. + + Args: + socket_path: Path to the socket file, e.g. /dev/shm/fd_task_queue_8000.sock + + Returns: + True if the socket is available (not in use), False otherwise. + """ + import errno + import os + import socket + + if not os.path.exists(socket_path): + return True + + # File exists, try to connect to see if someone is listening + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + try: + s.connect(socket_path) + return False + except OSError as e: + if e.errno in (errno.ECONNREFUSED, errno.ENOENT): + # Stale socket file: exists but nobody is listening + return True + return False + + def find_free_ports( port_range: tuple[int, int] = (8000, 65535), num_ports: int = 1, @@ -1044,10 +1079,14 @@ def status(self) -> dict: } -def parse_quantization(value: str): +def parse_quantization(value: Union[Dict, str]) -> Dict: """ Parse a JSON string into a dictionary. """ + if isinstance(value, dict): + return value + if value is None: + value = "null" try: return json.loads(value) except ValueError: diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 44a8c5f3578..284cee8843d 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -62,7 +62,7 @@ def __init__( local_rank: int, ): super().__init__(fd_config=fd_config, device=device) - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index c0e689735d4..07cb4dfb67b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ from paddle import nn from paddleformers.utils.log import logger -from fastdeploy.config import FDConfig +from fastdeploy.config import PREEMPTED_TOKEN_ID, FDConfig from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.request import ImagePosition, Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( @@ -45,6 +45,12 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) +from fastdeploy.model_executor.layers.attention.dsa_attention_backend import ( + DSAAttentionBackend, +) +from fastdeploy.model_executor.layers.attention.mla_attention_backend import ( + MLAAttentionBackend, +) from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( RoutingReplayManager, ) @@ -65,11 +71,13 @@ ) share_external_data = None + get_position_ids_and_slot_mapping = None elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx recover_decode_task = None share_external_data = None + get_position_ids_and_slot_mapping = None else: from fastdeploy.model_executor.ops.gpu import ( recover_decode_task, @@ -78,6 +86,7 @@ speculate_schedule_cache, set_data_ipc, unset_data_ipc, + get_position_ids_and_slot_mapping, ) import zmq @@ -85,7 +94,7 @@ from fastdeploy import envs from fastdeploy.engine.tasks import PoolingTask from fastdeploy.input.ernie4_5_vl_processor import DataProcessor -from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient +from fastdeploy.inter_communicator import IPCSignal, KVCacheStatus, ZmqIpcClient from fastdeploy.logger.deterministic_logger import DeterministicLogger from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata @@ -97,7 +106,7 @@ pre_process, rebuild_padding, save_output_normal, - save_output_specualate, + save_output_speculate, ) from fastdeploy.output.pooler import PoolerOutput from fastdeploy.worker.model_runner_base import ( @@ -105,7 +114,12 @@ DistributedStatus, ModelRunnerBase, ) -from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput +from fastdeploy.worker.output import ( + LogprobsTensors, + ModelOutputData, + ModelRunnerOutput, + SamplerOutput, +) class GPUModelRunner(ModelRunnerBase): @@ -119,13 +133,14 @@ def __init__( ): super().__init__(fd_config=fd_config, device=device) self.MAX_INFER_SEED = 9223372036854775806 - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id self.spec_method = self.fd_config.speculative_config.method self.speculative_decoding = self.spec_method is not None self.enable_logprob = fd_config.model_config.enable_logprob + self.enable_keep_sampling_mask = fd_config.model_config.enable_keep_sampling_mask self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.ori_vocab_size = self.fd_config.model_config.ori_vocab_size @@ -136,8 +151,8 @@ def __init__( if fd_config.model_config.max_logprobs == -1 else fd_config.model_config.max_logprobs ) - self.temp_scaled_logprobs = True - self.top_p_normalized_logprobs = True + self.temp_scaled_logprobs = False + self.top_p_normalized_logprobs = False self.prompt_logprobs_reqs: dict[str, Request] = {} self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} self.forward_batch_reqs_list: list[Request] = [None for _ in range(self.scheduler_config.max_num_seqs)] @@ -236,6 +251,27 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None + # ZMQ side-channel for sampling_mask in non-FD_USE_GET_SAVE_OUTPUT_V1 path + self.sampling_mask_zmq_client = None + if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.enable_keep_sampling_mask: + rank_id = self.parallel_config.local_data_parallel_id + port = self.parallel_config.engine_worker_queue_port[rank_id] + self.sampling_mask_zmq_client = ZmqIpcClient( + name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PUSH + ) + self.sampling_mask_zmq_client.connect() + logger.info(f"create send zmq sampling_mask_output_rank_{rank_id}_{port}") + + self.sampling_mask_async_queue = None + if self.sampling_mask_zmq_client is not None: + self.sampling_mask_async_queue = queue.Queue() + self._sampling_mask_send_thread = Thread( + target=self._async_sampling_mask_send_loop, + daemon=True, + name="WorkerAsyncSamplingMaskSend", + ) + self._sampling_mask_send_thread.start() + self.zmq_client = None self.async_output_queue = None if envs.FD_USE_GET_SAVE_OUTPUT_V1: @@ -286,6 +322,27 @@ def _async_output_busy_loop(self): except Exception as e: logger.exception("Exception in async output loop: %s", e) + def _async_sampling_mask_send_loop(self): + """Background thread: serialize and send sampling_mask over ZMQ.""" + while True: + try: + mask_list, accept_nums = self.sampling_mask_async_queue.get() + if accept_nums is None: + # Normal (non-speculative) path + mask_dict = {i: arr.tolist() for i, arr in enumerate(mask_list)} + else: + # Speculative path: group by accept_num + mask_dict = {} + offset = 0 + for i, n in enumerate(accept_nums): + n = int(n) + if n > 0: + mask_dict[i] = [arr.tolist() for arr in mask_list[offset : offset + n]] + offset += n + self.sampling_mask_zmq_client.send_pyobj(mask_dict) + except Exception as e: + logger.exception("Exception in async sampling_mask send loop: %s", e) + def exist_prefill(self): """ check whether prefill stage exist @@ -345,9 +402,7 @@ def _predict_next_launch_token_num(self) -> int: is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy() next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item() token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 - next_launch_token_num = ( - seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step - ) + next_launch_token_num = next_real_bsz * token_num_one_step return next_launch_token_num, next_real_bsz def only_prefill(self): @@ -819,9 +874,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = input_ids = prompt_token_ids + request.output_token_ids prompt_len = len(prompt_token_ids) # prompt_tokens - self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len] = np.array( - prompt_token_ids, dtype="int64" - ) + async_set_value(self.share_inputs["token_ids_all"][idx : idx + 1, :prompt_len], prompt_token_ids) # generated_token_ids fill -1 self.share_inputs["token_ids_all"][idx : idx + 1, prompt_len:] = -1 @@ -831,33 +884,39 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.deterministic_logger.log_prefill_input( request.request_id, idx, prefill_start_index, prefill_end_index, input_ids ) - logger.debug( f"Handle prefill request {request} at idx {idx}, " f"{prefill_start_index=}, {prefill_end_index=}, " f"need_prefilled_token_num={len(input_ids)}" f"prompt_len={prompt_len}" ) - self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array( - input_ids[prefill_start_index:prefill_end_index] + async_set_value( + self.share_inputs["input_ids"][idx : idx + 1, :length], + input_ids[prefill_start_index:prefill_end_index], ) encoder_block_num = len(request.block_tables) - self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 - self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( - request.block_tables, dtype="int32" + async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) + + async_set_value( + self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables ) - self.share_inputs["stop_flags"][idx : idx + 1] = False - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index - self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = length - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + + async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], False) + + async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], prefill_start_index) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], length) + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], length) self.exist_prefill_flag = True - self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0 - self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids) - self.share_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.share_inputs["step_seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["prompt_lens"][idx : idx + 1], len(input_ids)) + + async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.share_inputs["is_chunk_step"][idx : idx + 1] = prefill_end_index < len(input_ids) - self.share_inputs["step_idx"][idx : idx + 1] = ( - len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + async_set_value( + self.share_inputs["step_idx"][idx : idx + 1], + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0, ) # pooling model request.sampling_params is None if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: @@ -871,29 +930,40 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = prompt_token_ids = request.prompt_token_ids self.proposer.start_request(idx, request.request_id, prompt_token_ids) - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - # 1.prefix task(need regist) 2. chunkend task(not need regist) - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + # TODO: delete useless operation like this + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) self.exist_prefill_flag = False - self._cached_launch_token_num = -1 + if self._cached_launch_token_num != -1: + token_num_one_step = ( + (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1 + ) + self._cached_launch_token_num += token_num_one_step + self._cached_real_bsz += 1 if self.speculative_decoding: - # D speculate decode, seq_lens_this_time = length + 1 - self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + 1 - self.share_inputs["draft_tokens"][idx : idx + 1, 0 : length + 1] = paddle.to_tensor( - request.draft_token_ids[0 : length + 1], - dtype="int64", + # D first decode step, [Target first token, MTP first draft token] + # MTP in P only generate one draft token in any num_model_step config + draft_tokens_to_write = request.draft_token_ids[0:2] + if len(draft_tokens_to_write) != 2: + raise ValueError( + "Expected at least 2 draft tokens for speculative suffix decode, " + f"but got {len(draft_tokens_to_write)} for request {request.request_id}." + ) + async_set_value( + self.share_inputs["draft_tokens"][idx : idx + 1, 0:2], + draft_tokens_to_write, ) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 2) + logger.debug( + f"insert request {request.request_id} idx: {idx} suffix tokens {request.draft_token_ids}" + ) elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") encoder_block_num = len(request.block_tables) - self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 + async_set_value(self.share_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num) + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) if current_platform.is_cuda(): async_set_value( self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num], request.block_tables @@ -902,6 +972,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( request.block_tables, dtype="int32" ) + # CPU Tensor self.share_inputs["preempted_idx"][idx : idx + 1, :] = 0 continue else: # preempted task @@ -910,70 +981,77 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = elif request.task_type.value == RequestType.ABORT.value: logger.info(f"Handle abort request {request} at idx {idx}") self.share_inputs["preempted_idx"][idx : idx + 1, :] = 1 - self.share_inputs["block_tables"][idx : idx + 1, :] = -1 - self.share_inputs["stop_flags"][idx : idx + 1] = True - self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1] = 0 - self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 - self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 - self.share_inputs["is_block_step"][idx : idx + 1] = False + async_set_value(self.share_inputs["block_tables"][idx : idx + 1, :], -1) + async_set_value(self.share_inputs["stop_flags"][idx : idx + 1], True) + async_set_value(self.share_inputs["seq_lens_this_time_buffer"][idx : idx + 1], 0) + async_set_value(self.share_inputs["seq_lens_decoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["seq_lens_encoder"][idx : idx + 1], 0) + async_set_value(self.share_inputs["is_block_step"][idx : idx + 1], False) self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens - self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) - - self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) - self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) - self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) - self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) + self.share_inputs["top_p_list"][idx] = request.get("top_p", 0.7) self.share_inputs["min_p_list"][idx] = request.get("min_p", 0.0) - self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) - self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) - self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) - self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) - self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False) - self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get( - "top_p_normalized_logprobs", False + self.share_inputs["top_k_list"][idx] = request.get("top_k", 0) + async_set_value(self.share_inputs["eos_token_id"][:], request.eos_token_ids) + async_set_value(self.share_inputs["top_p"][idx : idx + 1], request.get("top_p", 0.7)) + async_set_value(self.share_inputs["top_k"][idx : idx + 1], request.get("top_k", 0)) + async_set_value(self.share_inputs["min_p"][idx : idx + 1], request.get("min_p", 0.0)) + async_set_value(self.share_inputs["temperature"][idx : idx + 1], request.get("temperature", 0.95)) + async_set_value(self.share_inputs["penalty_score"][idx : idx + 1], request.get("repetition_penalty", 1.0)) + async_set_value(self.share_inputs["frequency_score"][idx : idx + 1], request.get("frequency_penalty", 0.0)) + async_set_value(self.share_inputs["presence_score"][idx : idx + 1], request.get("presence_penalty", 0.0)) + async_set_value( + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1], request.get("temp_scaled_logprobs", False) ) - self.share_inputs["generated_modality"][idx : idx + 1] = request.get("generated_modality", 0) - - self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) - self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( - "max_tokens", self.model_config.max_model_len + async_set_value( + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1], + request.get("top_p_normalized_logprobs", False), + ) + async_set_value( + self.share_inputs["generated_modality"][idx : idx + 1], request.get("generated_modality", 0) + ) + async_set_value(self.share_inputs["min_dec_len"][idx : idx + 1], request.get("min_tokens", 1)) + async_set_value( + self.share_inputs["max_dec_len"][idx : idx + 1], + request.get("max_tokens", self.model_config.max_model_len), ) if request.get("seed") is not None: - self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed") + async_set_value(self.share_inputs["infer_seed"][idx : idx + 1], request.get("seed")) if request.get("bad_words_token_ids") is not None and len(request.get("bad_words_token_ids")) > 0: bad_words_len = len(request.get("bad_words_token_ids")) - self.share_inputs["bad_tokens_len"][idx] = bad_words_len - self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len] = np.array( - request.get("bad_words_token_ids"), dtype="int64" + async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], bad_words_len) + async_set_value( + self.share_inputs["bad_tokens"][idx : idx + 1, :bad_words_len], request.get("bad_words_token_ids") ) else: - self.share_inputs["bad_tokens_len"][idx] = 1 - self.share_inputs["bad_tokens"][idx : idx + 1, :] = np.array([-1], dtype="int64") + async_set_value(self.share_inputs["bad_tokens_len"][idx : idx + 1], 1) + async_set_value(self.share_inputs["bad_tokens"][idx : idx + 1, :], -1) if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None: stop_seqs_num = len(request.get("stop_seqs_len")) for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num): request.sampling_params.stop_seqs_len.append(0) - self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = np.array( - request.sampling_params.stop_seqs_len, dtype="int32" + async_set_value( + self.share_inputs["stop_seqs_len"][idx : idx + 1, :], request.sampling_params.stop_seqs_len ) - self.share_inputs["stop_seqs"][ - idx : idx + 1, :stop_seqs_num, : len(request.get("stop_token_ids")[0]) - ] = np.array(request.get("stop_token_ids"), dtype="int64") + # Pad each stop sequence to stop_seqs_max_len, then fill remaining rows + # and write the whole block at once to avoid partial slicing on the + # third dimension, which may cause async_set_value stride issues on + # non-contiguous memory. + stop_token_ids = request.get("stop_token_ids") + max_len = self.model_config.stop_seqs_max_len + padded = [seq + [-1] * (max_len - len(seq)) for seq in stop_token_ids] + padded.extend([[-1] * max_len] * (self.model_config.max_stop_seqs_num - stop_seqs_num)) + async_set_value(self.share_inputs["stop_seqs"][idx : idx + 1, :, :], padded) else: - self.share_inputs["stop_seqs_len"][idx : idx + 1, :] = 0 + async_set_value(self.share_inputs["stop_seqs_len"][idx : idx + 1, :], 0) self.pooling_params = batch_pooling_params # For logits processors @@ -982,7 +1060,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) self._process_mm_features(req_dicts) - if len(rope_3d_position_ids["position_ids_idx"]) > 0: + + if len(rope_3d_position_ids["position_ids_idx"]) > 0 and self.enable_mm: packed_position_ids = paddle.to_tensor( np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64" ) @@ -1120,10 +1199,12 @@ def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None: """Prepare the model inputs""" + if self.enable_mm and self.share_inputs["image_features_list"] is not None: tensor_feats = [t for t in self.share_inputs["image_features_list"] if isinstance(t, paddle.Tensor)] if tensor_feats: self.share_inputs["image_features"] = paddle.concat(tensor_feats, axis=0) + recover_decode_task( self.share_inputs["stop_flags"], self.share_inputs["seq_lens_this_time"], @@ -1143,20 +1224,16 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p for req in self.forward_batch_reqs_list if req is not None and req.sampling_params is not None and req.sampling_params.logprobs is not None ] - if len(logprobs_reqs): - self.max_logprobs = ( - max( - [ - self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs - for req in logprobs_reqs - ] - ) - if not self.speculative_decoding - else 20 - ) - self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs) - self.top_p_normalized_logprobs = any( - req.sampling_params.top_p_normalized_logprobs for req in logprobs_reqs + self.temp_scaled_logprobs = any(req.sampling_params.temp_scaled_logprobs for req in logprobs_reqs) + self.top_p_normalized_logprobs = any( + req.sampling_params.top_p_normalized_logprobs and req.sampling_params.top_p != 1.0 for req in logprobs_reqs + ) + if logprobs_reqs: + self.max_logprobs = max( + [ + self.ori_vocab_size if req.sampling_params.logprobs < 0 else req.sampling_params.logprobs + for req in logprobs_reqs + ] ) elif self.enable_logprob: self.max_logprobs = None if not self.speculative_decoding else 0 @@ -1209,6 +1286,7 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], + top_p_list=self.share_inputs["top_p_list"], top_k=self.share_inputs["top_k"], top_k_list=self.share_inputs["top_k_list"], min_p=self.share_inputs["min_p"], @@ -1233,9 +1311,49 @@ def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_p top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], logits_processors=self.share_inputs["logits_processors"], share_inputs=self.share_inputs, + keep_sampling_mask=self.enable_keep_sampling_mask, ) return token_num, token_num_event + def _compute_position_ids_and_slot_mapping(self, total_token_num) -> None: + """Compute position_ids and slot_mapping for KV cache addressing. + This is a general computation based on sequence length info and block tables, + applicable to all models that need per-token KV cache physical slot addresses. + Results are stored in self.forward_meta. + """ + # NOTE(zhushengguang): Only support MLAAttentionBackend and DSAAttentionBackend currently. + # Also needed when R3 (Routing Replay) is enabled for slot_mapping_buffer computation. + needs_slot_mapping = isinstance(self.attn_backends[0], (MLAAttentionBackend, DSAAttentionBackend)) + needs_slot_mapping = (self.routing_replay_manager is not None) or needs_slot_mapping + if not needs_slot_mapping: + return + # Directly write to existing buffers (no memory allocation or copy needed) + position_ids_buffer = self.share_inputs["position_ids_buffer"][:total_token_num] + slot_mapping_buffer = self.share_inputs["slot_mapping_buffer"][:total_token_num] + get_position_ids_and_slot_mapping( + self.forward_meta.seq_lens_encoder, + self.forward_meta.seq_lens_decoder, + self.forward_meta.seq_lens_this_time, + self.forward_meta.batch_id_per_token, + self.forward_meta.block_tables, + position_ids_buffer, + slot_mapping_buffer, + self.cache_config.block_size, + ) + # Store views in forward_meta + self.forward_meta.position_ids = position_ids_buffer + self.forward_meta.slot_mapping = slot_mapping_buffer + + # Debug: print all tokens' position_ids and slot_mapping in R3 debug mode + if self.routing_replay_manager is not None and self.routing_replay_manager.debug_mode: + logger.info(f"[R3 Debug] token mapping: num_tokens={total_token_num}") + logger.info(" token | position_id | slot ") + logger.info(" " + "-" * 30) + for i in range(total_token_num): + logger.info( + f" {i:4d} | {int(self.forward_meta.position_ids[i]):8d} | {int(self.forward_meta.slot_mapping[i]):7d}" + ) + def _process_reorder(self) -> None: if self.attn_backends and getattr(self.attn_backends[0], "enable_ids_reorder", False): self.share_inputs.enable_pd_reorder = True @@ -1279,11 +1397,10 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None - if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() - num_running_requests = self.share_inputs["seq_lens_this_time"].shape[0] + device_routing_buffer = None + if self.routing_replay_manager is not None: + device_routing_buffer = self.routing_replay_manager.get_device_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1310,9 +1427,18 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + device_routing_buffer=device_routing_buffer, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.share_inputs: + self.forward_meta.decode_block_indices = self.share_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.share_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.share_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.share_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.share_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.share_inputs["decode_tmp_d"] + dist_status = self.collect_distributed_status() if_only_decode = dist_status.if_only_decode @@ -1404,10 +1530,12 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. - # 2. If no need to profile, create kv cache if cache managers do not exist. + # 2. If no need to profile, create kv cache unless kvcache_storage_backend or + # p/d disaggregation is enabled. Note: CPU cache (num_cpu_blocks > 0) does NOT + # prevent GPU runner from creating GPU cache tensors; cache transfer manager + # handles CPU<->GPU swap on top of the GPU tensors created here. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -1539,6 +1667,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) @@ -1697,6 +1827,7 @@ def _dummy_sampler_run( self.increment_value, accept_all_drafts, reject_all_drafts, + real_bsz=batch_size, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -1820,10 +1951,13 @@ def _dummy_run( while True: # 1. Initialize forward meta and attention meta data - self._prepare_inputs(is_dummy_or_profile_run=True) + token_num, _ = self._prepare_inputs(is_dummy_or_profile_run=True) # 2. Padding inputs for cuda graph self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph self.padding_cudagraph_inputs() + # Compute position_ids and slot_mapping + + self._compute_position_ids_and_slot_mapping(total_token_num=token_num) model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] @@ -1895,8 +2029,7 @@ def capture_model(self) -> None: logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) - elif self.speculative_decoding and self.spec_method == SpecMethod.MTP: - # Capture Target Model without bsz 1 + elif self.speculative_decoding and self.spec_method in [SpecMethod.MTP, SpecMethod.SUFFIX]: for capture_size in sorted(capture_sizes, reverse=True): expected_decode_len = (self.speculative_config.num_speculative_tokens + 1) * 2 self._dummy_run( @@ -2035,6 +2168,95 @@ def _execute_empty_mtp_input(self, forward_meta) -> None: for _ in range(self.fd_config.speculative_config.num_model_steps): self.proposer.model.empty_input_forward(forward_meta) + def _make_preempted_batch_output(self): + """Build a minimal batch-shaped control output for preempted slots. + + This is used when the current step contains only preempted/aborted + requests and therefore produces no normal model tokens. The helper + fabricates a lightweight batch output so the existing save_output path + can still return PREEMPTED_TOKEN_ID for the affected slots. + """ + preempted_indices = paddle.nonzero(self.share_inputs["preempted_idx"] == 1) + bsz = int(preempted_indices[-1][0].item()) + 1 + + fake_sampled_token_ids = paddle.where( + self.share_inputs["preempted_idx"][:bsz] == 1, + PREEMPTED_TOKEN_ID, + -1, + ).astype("int64") + sampled_token_ids = self.share_inputs["sampled_token_ids"].cpu() + sampled_token_ids[:bsz].copy_(fake_sampled_token_ids, True) + self.share_inputs["sampled_token_ids"].copy_(sampled_token_ids, True) + + fake_logprobs_tensors = None + if self.enable_logprob: + fake_logprobs_tensors = LogprobsTensors( + logprob_token_ids=paddle.zeros([bsz, 1], dtype="int64", device="cpu"), + logprobs=paddle.zeros([bsz, 1], dtype="float32", device="cpu"), + selected_token_ranks=paddle.zeros([bsz], dtype="int64", device="cpu"), + ) + + if self.speculative_decoding: + self.share_inputs["accept_tokens"][:bsz].fill_(0) + self.share_inputs["accept_num"][:bsz].fill_(0) + self.share_inputs["accept_tokens_cpu"].copy_(self.share_inputs["accept_tokens"], True) + self.share_inputs["accept_num_cpu"].copy_(self.share_inputs["accept_num"], True) + self.share_inputs["seq_lens_decoder_cpu"].copy_(self.share_inputs["seq_lens_decoder"], True) + self.share_inputs["prompt_lens_cpu"].copy_(self.share_inputs["prompt_lens"], True) + sampler_output = SamplerOutput( + sampled_token_ids=fake_sampled_token_ids, + logprobs_tensors=fake_logprobs_tensors, + token_num_per_batch=(self.share_inputs["accept_num_cpu"][:bsz] if self.enable_logprob else None), + cu_batch_token_offset=( + paddle.zeros([bsz + 1], dtype="int32", device="cpu") if self.enable_logprob else None + ), + ) + else: + sampler_output = SamplerOutput( + sampled_token_ids=fake_sampled_token_ids, + logprobs_tensors=fake_logprobs_tensors, + ) + + index_to_batch_id = { + i: self.share_inputs["index_to_batch_id"][i] + for i in range(bsz) + if i in self.share_inputs["index_to_batch_id"] + } + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + not_need_stop_device=self.share_inputs["not_need_stop_device"], + input_ids=self.share_inputs["input_ids"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=None, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + token_ids_all=self.share_inputs["token_ids_all"], + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + min_tokens=self.share_inputs["min_dec_len"], + prompt_lens=self.share_inputs["prompt_lens"], + mask_rollback=self.share_inputs["mask_rollback"], + prompt_logprobs_list=None, + index_to_batch_id=index_to_batch_id, + enable_pd_reorder=getattr(self.share_inputs, "enable_pd_reorder", False), + ) + return model_output_data, sampler_output + def execute_model( self, model_forward_batch: Optional[List[Request]] = None, @@ -2069,14 +2291,25 @@ def execute_model_normal( and self.parallel_config.use_ep ): self._execute_empty_mtp_input(self.forward_meta) - return - model_output_data, sampler_output, post_process_event = self._postprocess( - model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz - ) - if model_output_data is not None: - # synchronizes the async DtoH copies of sampled_token_ids. - post_process_event.synchronize() - self._save_model_output(model_output_data, sampler_output) + + if paddle.sum(self.share_inputs["preempted_idx"]) > 0: + logger.info( + f"All requests in batch are preempted, real_bsz: {real_bsz} preempted: {paddle.sum(self.share_inputs['preempted_idx'])}" + ) + model_output_data, sampler_output = self._make_preempted_batch_output() + self.share_inputs["last_preempted_idx"].copy_(self.share_inputs["preempted_idx"]) + self.share_inputs["preempted_idx"][:] = 0 + self._save_model_output(model_output_data, sampler_output) + else: + model_output_data, sampler_output, post_process_event = self._postprocess( + model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz + ) + if model_output_data is not None: + # synchronizes the async DtoH copies of sampled_token_ids. + post_process_event.synchronize() + if self.routing_replay_manager is not None: + self.routing_replay_manager.flush_pending_save() + self._save_model_output(model_output_data, sampler_output) def execute_model_overlap( self, @@ -2092,6 +2325,8 @@ def execute_model_overlap( if self._cached_model_output_data is not None: # synchronizes the async DtoH copies of sampled_token_ids. self._cached_post_process_event.synchronize() + if self.routing_replay_manager is not None: + self.routing_replay_manager.flush_pending_save() self._save_model_output( self._cached_model_output_data, self._cached_sampler_output, @@ -2102,7 +2337,10 @@ def execute_model_overlap( # ensuring that the token count for the current batch is ready to be computed and reused in the subsequent batch. token_num_event.synchronize() next_launch_token_num, next_real_bsz = self._predict_next_launch_token_num() - real_bsz = (self.share_inputs["seq_lens_this_time_cpu"].numpy() > 0).sum().item() + seq_lens_this_time_cpu_numpy = self.share_inputs["seq_lens_this_time_cpu"].numpy() + real_bsz = (seq_lens_this_time_cpu_numpy > 0).sum().item() + if self.routing_replay_manager is not None: + self.routing_replay_manager.token_num_overlap = seq_lens_this_time_cpu_numpy.sum().item() if real_bsz > 0 and model_output is not None: model_output_data, sampler_output, post_process_event = self._postprocess( model_output, p_done_idxs, model_forward_batch, num_running_requests, real_bsz @@ -2111,6 +2349,22 @@ def execute_model_overlap( self._cached_sampler_output = sampler_output self._cached_post_process_event = post_process_event else: + if ( + self.fd_config.speculative_config.method == SpecMethod.MTP + and hasattr(self.proposer.model, "empty_input_forward") + and self.parallel_config.use_ep + ): + self._execute_empty_mtp_input(self.forward_meta) + + if paddle.sum(self.share_inputs["preempted_idx"]) > 0: + logger.info( + f"All requests in batch are preempted, real_bsz: {real_bsz} preempted: {paddle.sum(self.share_inputs['preempted_idx'])}" + ) + model_output_data, sampler_output = self._make_preempted_batch_output() + self.share_inputs["last_preempted_idx"].copy_(self.share_inputs["preempted_idx"]) + self.share_inputs["preempted_idx"][:] = 0 + self._save_model_output(model_output_data, sampler_output) + self._cached_model_output_data = None self._cached_sampler_output = None self._cached_post_process_event = None @@ -2144,11 +2398,6 @@ def _preprocess( p_done_idxs = self._get_p_done_idxs_gd(model_forward_batch, num_running_requests) self.sampler.pre_process(p_done_idxs) - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.pending_update_positions = self.routing_replay_manager.get_token_positions( - seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time_buffer"], - ) # Update state of logits processor for proc in self.sampling_metadata.logits_processors: @@ -2156,6 +2405,8 @@ def _preprocess( # Padding inputs for cuda graph self.padding_cudagraph_inputs() + # Compute position_ids and slot_mapping + self._compute_position_ids_and_slot_mapping(total_token_num=current_launch_token_num) model_inputs = {} model_inputs["ids_remove_padding"] = self.share_inputs["ids_remove_padding"] @@ -2321,6 +2572,7 @@ def _postprocess( self.share_inputs, real_output_token_num, self.increment_value, + real_bsz=real_bsz, ) if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( @@ -2404,6 +2656,16 @@ def _postprocess( # 5.1. Async cpy post_process_event = paddle.device.cuda.create_event() + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + # If one query is preempted, there is no sampled token for it, we use token_id PREEMPTED_TOKEN_ID to signal server, abort is finished. + paddle.assign( + paddle.where( + self.share_inputs["last_preempted_idx"][: sampler_output.sampled_token_ids.shape[0]] == 1, + PREEMPTED_TOKEN_ID, + sampler_output.sampled_token_ids, + ), + sampler_output.sampled_token_ids, + ) # if not self.speculative_decoding: self.share_inputs["sampled_token_ids"].copy_(sampler_output.sampled_token_ids, False) if self.speculative_decoding: @@ -2463,13 +2725,18 @@ def _save_model_output( sampler_output, ): if self.speculative_decoding: - skip_save_output = self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" - save_output_specualate( + save_output_speculate( sampler_output=sampler_output, model_output=model_output_data, share_inputs=self.share_inputs, + local_rank=self.local_rank, + tensor_parallel_rank=self.parallel_config.tensor_parallel_rank, save_each_rank=self.parallel_config.use_ep, - skip_save_output=skip_save_output, + sampling_mask_async_queue=self.sampling_mask_async_queue, + is_mtp_prefill=( + self.spec_method == SpecMethod.MTP and self.scheduler_config.splitwise_role == "prefill" + ), + proposer_share_inputs=self.proposer.model_inputs if self.spec_method == SpecMethod.MTP else None, ) else: save_output_normal( @@ -2478,6 +2745,7 @@ def _save_model_output( share_inputs=self.share_inputs, async_output_queue=self.async_output_queue, save_each_rank=self.parallel_config.use_ep, + sampling_mask_async_queue=self.sampling_mask_async_queue, ) def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]: @@ -2657,16 +2925,23 @@ def cal_theortical_kvcache(self): def clear_cache(self, profile=False): """Clear cached data from shared inputs and forward metadata""" create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not create_cache_tensor: - for name, tensor in self.cache_kvs_map.items(): - unset_data_ipc(tensor, name, True, False) - self.cache_ready_signal.value[local_rank] = 0 + if not profile: + if create_cache_tensor: + if self.fd_config.cache_config.num_cpu_blocks > 0: + logger.info("Waiting for cache transfer manager to unlink cuda ipc") + while self.cache_ready_signal.value[local_rank] != 0: + time.sleep(0.1) + logger.info("Stop waiting! cache transfer manager has unlinked cuda ipc") + else: + for name, tensor in self.cache_kvs_map.items(): + unset_data_ipc(tensor, name, True, False) + self.cache_ready_signal.value[local_rank] = 0 + self.cache_kvs_map.clear() self.share_inputs.pop("caches", None) if self.forward_meta is not None: @@ -2677,17 +2952,31 @@ def clear_parameters(self, pid): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() + if ( + self.speculative_decoding + and self.spec_method == SpecMethod.MTP + and self.graph_opt_config.draft_model_use_cudagraph + ): + self.proposer.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle ) - if self.spec_method == SpecMethod.MTP: - self.proposer.model.clear_grpah_opt_backend() + + # NOTE(wangyanpeng): MTP cache must be cleared before clearing the main KV cache + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: self.proposer.clear_mtp_cache() self.clear_cache() paddle.device.cuda.empty_cache() + # clear overlap status + self._cached_model_output_data = None + self._cached_sampler_output = None + self._cached_post_process_event = None + self._cached_launch_token_num = -1 + self._cached_real_bsz = -1 + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") def clear_requests(self): @@ -2700,14 +2989,21 @@ def clear_requests(self): # Routing Replay if self.routing_replay_manager: - self.routing_replay_manager.clear_all_request() + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" # Update parameters - self.dynamic_weight_manager.update_parameters( - pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle - ) + if self.dynamic_weight_manager.use_gdr_checkpoint_transfer: + if self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle: + self.dynamic_weight_manager.restart_communication_group() + if self.dynamic_weight_manager.parallel_config.enable_expert_parallel: + self.dynamic_weight_manager.recreate_deepep_buffer() + self.dynamic_weight_manager.update_weights_by_gdr(restore_cleared_params=True) + else: + self.dynamic_weight_manager.update_parameters( + pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle + ) # Reset share_inputs self.share_inputs.reset_share_inputs() @@ -2718,13 +3014,95 @@ def update_parameters(self, pid): # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() + # Send single self.dynamic_weight_manager.finalize_update(pid) - self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def update_weights(self, version: str = None, verify_checksum: bool = False): - return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) + if self.dynamic_weight_manager.use_gdr_checkpoint_transfer: + release_cache = bool((self.fd_config.load_config.rsync_config or {}).get("gdr_release_cache", False)) + + cache_clear_cost = 0.0 + cache_rebuild_cost = 0.0 + if release_cache: + clear_start = time.perf_counter() + self._clear_cache_for_gdr_weight_update() + cache_clear_cost = time.perf_counter() - clear_start + + result = self.dynamic_weight_manager.update_weights_by_gdr(version, verify_checksum) + + if release_cache: + rebuild_start = time.perf_counter() + self._rebuild_cache_after_gdr_weight_update() + cache_rebuild_cost = time.perf_counter() - rebuild_start + + result["release_cache"] = release_cache + result["cache_clear_cost"] = cache_clear_cost + result["cache_rebuild_cost"] = cache_rebuild_cost + self.dynamic_weight_manager.finalize_update() + return result + else: + result = self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum) + self.dynamic_weight_manager.finalize_update() + return result + + def _clear_cache_for_gdr_weight_update(self): + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.CLEARING + if self.use_cudagraph: + self.model.clear_graph_opt_backend() + if envs.FD_USE_BLOCK_WISE_CUDA_GRAPH: + from fastdeploy.model_executor.graph_optimization.cuda_graph_op import ( + clear_all_block_wise_graphs, + ) + + clear_all_block_wise_graphs() + if ( + self.speculative_decoding + and self.spec_method == SpecMethod.MTP + and self.graph_opt_config.draft_model_use_cudagraph + ): + self.proposer.model.clear_graph_opt_backend() + if self.speculative_decoding and self.spec_method == SpecMethod.MTP: + self.proposer.clear_mtp_cache() + self.clear_cache() + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.CLEARED: + time.sleep(0.01) + paddle.device.cuda.empty_cache() + self._cached_model_output_data = None + self._cached_sampler_output = None + self._cached_post_process_event = None + self._cached_launch_token_num = -1 + self._cached_real_bsz = -1 + + def _rebuild_cache_after_gdr_weight_update(self): + cache_flag = ( + self.fd_config.cache_config.num_cpu_blocks > 0 + or self.fd_config.cache_config.kvcache_storage_backend is not None + ) + kv_cache_status = self.kv_cache_status if cache_flag else None + if kv_cache_status: + kv_cache_status.value[0] = KVCacheStatus.UPDATING + self.share_inputs.reset_share_inputs() + if self.spec_method == SpecMethod.MTP: + self.proposer.model_inputs.reset_model_inputs() + if not self.enable_cache_manager_v1: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) + self.initialize_kv_cache() + if self.use_cudagraph: + self.capture_model() + if self.fd_config.routing_replay_config.enable_routing_replay: + self.routing_replay_manager.update_suspend_routing_replay() + if kv_cache_status: + while kv_cache_status.value[0] != KVCacheStatus.NORMAL: + time.sleep(0.01) def sleep(self, tags): @@ -2737,7 +3115,7 @@ def sleep(self, tags): logger.info("GPU model runner's weight is already sleeping, no need to sleep again!") return if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() if self.fd_config.parallel_config.enable_expert_parallel: self.dynamic_weight_manager.clear_deepep_buffer() self.dynamic_weight_manager.clear_model_weight() @@ -3103,6 +3481,5 @@ def initialize_routing_replay_manager(self): # Use updated block number self.routing_replay_manager = RoutingReplayManager( fd_config=self.fd_config, - block_table=self.share_inputs["block_tables"], total_block_num=self.num_gpu_blocks, ) diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index aebf3f21111..423d9fb54a5 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -126,14 +126,12 @@ def determine_available_memory(self) -> int: before_run_meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) logger.info( - ( - "Before running the profile, the memory usage info is as follows:", - f"\nDevice Total memory: {before_run_meminfo.total / Gb}", - f"\nDevice used memory: {before_run_meminfo.used / Gb}", - f"\nDevice free memory: {before_run_meminfo.free / Gb}", - f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}", - ) + "Before running the profile, the memory usage info is as follows:" + f"\nDevice Total memory: {before_run_meminfo.total / Gb}" + f"\nDevice used memory: {before_run_meminfo.used / Gb}" + f"\nDevice free memory: {before_run_meminfo.free / Gb}" + f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}" + f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}" ) # 2. Profile run @@ -161,16 +159,14 @@ def determine_available_memory(self) -> int: end_time = time.perf_counter() logger.info( - ( - "After running the profile, the memory usage info is as follows:", - f"\nDevice Total memory: {after_run_meminfo.total / Gb}", - f"\nDevice used memory: {after_run_meminfo.used / Gb}", - f"\nDevice free memory: {after_run_meminfo.free / Gb}", - f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}", - f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}", - f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}", - f"Profile time: {end_time - start_time}", - ) + "After running the profile, the memory usage info is as follows:" + f"\nDevice Total memory: {after_run_meminfo.total / Gb}" + f"\nDevice used memory: {after_run_meminfo.used / Gb}" + f"\nDevice free memory: {after_run_meminfo.free / Gb}" + f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}" + f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}" + f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}" + f"Profile time: {end_time - start_time}" ) return available_kv_cache_memory # return to calculate the block num in this device diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py index 625aca86db1..44be900bb73 100644 --- a/fastdeploy/worker/iluvatar_worker.py +++ b/fastdeploy/worker/iluvatar_worker.py @@ -40,7 +40,7 @@ def __init__( local_rank: int, rank: int, ): - if fd_config.model_config.enable_mm: + if fd_config.enable_mm_runtime: paddle.set_flags({"FLAGS_enable_ixattnbkd": True, "FLAGS_enable_ixdnn_attn": False}) super(IluvatarWorker, self).__init__( fd_config=fd_config, diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 363dfb63097..53f6db55649 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -17,13 +17,7 @@ import paddle from paddleformers.utils.log import logger -from fastdeploy.config import ( - CacheConfig, - DeployModality, - FDConfig, - ModelConfig, - SpeculativeConfig, -) +from fastdeploy.config import CacheConfig, FDConfig, ModelConfig, SpeculativeConfig from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.logits_processor import build_logits_processors from fastdeploy.platforms import current_platform @@ -101,7 +95,8 @@ def __init__(self, fd_config: FDConfig) -> None: self.scheduler_config = fd_config.scheduler_config self.speculative_config: SpeculativeConfig = fd_config.speculative_config self.speculative_decoding = self.speculative_config.method is not None - self.enable_mm = self.model_config.enable_mm + self.is_mm_model = self.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.enable_expert_parallel = fd_config.parallel_config.enable_expert_parallel self.index_to_batch_id = {} self.enable_pd_reorder = False @@ -126,6 +121,7 @@ def init_share_inputs(self): ) self.eos_token_id = paddle.full([self.model_config.eos_tokens_lens, 1], 0, dtype="int64") self.top_p = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.top_p_list = [self.model_config.top_p] * max_num_seqs self.top_k = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.top_k_list = [0] * max_num_seqs self.min_p = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") @@ -193,6 +189,10 @@ def init_share_inputs(self): self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") + # Initialize addressing buffers + self.position_ids_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int64) + self.slot_mapping_buffer = paddle.zeros([self.max_chunk_tokens], dtype=paddle.int64) + # Declare AttentionBackend buffers self.decoder_batch_ids = None self.decoder_tile_ids_per_batch = None @@ -206,6 +206,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers (initialized by _initialize_attn_backend) + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Initialize thinking related buffers self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool") @@ -231,6 +238,9 @@ def init_share_inputs(self): model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) + if self.is_mm_model: + self.image_features = None + self.image_features_list = None # Set block tables pre_max_block_num = ( @@ -394,6 +404,7 @@ def swap_data(tensor, idx1, idx2): # swap_data(self.recompute_token_num, i1, i2) # # Swap list-based arrays (lists don't need clone) + self.top_p_list[i1], self.top_p_list[i2] = self.top_p_list[i2], self.top_p_list[i1] self.top_k_list[i1], self.top_k_list[i2] = self.top_k_list[i2], self.top_k_list[i1] self.min_p_list[i1], self.min_p_list[i2] = self.min_p_list[i2], self.min_p_list[i1] @@ -545,6 +556,7 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "top_p_normalized_logprobs", False) # Reset list variables (not paddle tensors) + self.top_p_list = [self.model_config.top_p] * max_num_seqs self.top_k_list = [0] * max_num_seqs self.min_p_list = [0.0] * max_num_seqs @@ -562,7 +574,6 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "step_idx", 0) # fill_paddle_tensor(self, "not_need_stop", False) fill_paddle_tensor(self, "not_need_stop_device", False) - fill_paddle_tensor(self, "sampled_token_ids", -1) fill_paddle_tensor(self, "stop_flags", True) fill_paddle_tensor(self, "bad_tokens", -1) @@ -677,10 +688,19 @@ def reset_share_inputs(self): model_config=self.model_config, partial_rotary_factor=self.model_config.partial_rotary_factor, ) + if self.is_mm_model: + self.image_features = None + self.image_features_list = None # Reset other miscellaneous tensors fill_paddle_tensor(self, "mask_rollback", 0) fill_paddle_tensor(self, "preempted_idx", 0) + fill_paddle_tensor(self, "last_preempted_idx", 0) + + # Reset tensors for overlap + self.sampled_token_ids = paddle.full([max_num_seqs, 1], -1, dtype="int64").pin_memory() + self.seq_lens_this_time_cpu = paddle.full([max_num_seqs, 1], 0, dtype="int32").pin_memory() + self.is_block_step_cpu = paddle.full([max_num_seqs], False, dtype="bool").pin_memory() logger.info("share_inputs reset completed") except Exception as e: @@ -689,7 +709,7 @@ def reset_share_inputs(self): class ProposerInputBatch(InputBatch): def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> None: - self.enable_mm = fd_config.model_config.enable_mm + self.enable_mm = fd_config.enable_mm_runtime self.num_model_steps = fd_config.speculative_config.num_model_steps self.index_to_batch_id = {} self.target_model_input_batch = target_model_input_batch @@ -804,6 +824,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Input tokens self.draft_tokens = paddle.full( @@ -863,18 +890,15 @@ def init_share_inputs(self): -1, dtype="int32", ) - if self.fd_config.deploy_modality != DeployModality.TEXT: - self.attn_mask_offsets = paddle.full( - shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len], - fill_value=-1, - dtype="int32", - ) - self.attn_mask_offsets_full = paddle.full( - [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32" - ) - self.attn_mask_offsets_decoder = paddle.full( - [self.scheduler_config.max_num_seqs, 1], -1, dtype="int32" - ) + self.attn_mask_offsets = paddle.full( + shape=[self.scheduler_config.max_num_seqs * self.model_config.max_model_len], + fill_value=-1, + dtype="int32", + ) + self.attn_mask_offsets_full = paddle.full( + [self.scheduler_config.max_num_seqs, self.model_config.max_model_len], -1, dtype="int32" + ) + self.attn_mask_offsets_decoder = paddle.full([self.scheduler_config.max_num_seqs, 1], -1, dtype="int32") def swap_states(self, i1, i2) -> None: def swap_data(tensor, idx1, idx2): @@ -896,7 +920,7 @@ def swap_data(tensor, idx1, idx2): swap_data(self.input_ids_len, i1, i2) swap_data(self.mask_rollback, i1, i2) swap_data(self.recompute_token_num, i1, i2) - if self.enable_mm and self.fd_config.deploy_modality != DeployModality.TEXT: + if self.enable_mm: swap_data(self.attn_mask_offsets_full, i1, i2) swap_data(self.attn_mask_offsets_decoder, i1, i2) @@ -915,8 +939,12 @@ def reset_model_inputs(self) -> None: self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) fill_paddle_tensor(self, "input_ids_cpu", -1) - # acceptance rate decline when reset seq_lens_this_time - # self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"]) + # NOTE(fix): Must reset seq_lens_this_time_buffer to avoid stale values from previous + # RL round causing illegal memory access during CUDAGraph recapture (error 700). + # When draft_model_use_cudagraph=true, padding_cudagraph_inputs() uses the full + # seq_lens_this_time_buffer tensor; residual non-zero values in high-index slots + # (from previous round) will make attention kernel access invalid block_table entries. + fill_paddle_tensor(self, "seq_lens_this_time_buffer", 0) self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"]) self.seq_lens_decoder = paddle.clone(self.target_model_input_batch["seq_lens_decoder"]) @@ -925,8 +953,19 @@ def reset_model_inputs(self) -> None: self.step_idx = paddle.clone(self.target_model_input_batch["step_idx"]) self.stop_flags = paddle.clone(self.target_model_input_batch["stop_flags"]) self.not_need_stop = paddle.to_tensor([False], dtype="bool", place="cpu") + self.not_need_stop_device = paddle.to_tensor([False], dtype="bool") self.index_to_batch_id = {} if current_platform.is_cuda(): + # NOTE(fix): These tensors get reshaped during runtime inference, so we must + # recreate them at full initial size instead of cloning the (possibly resized) + # target_model_input_batch tensors. Otherwise CUDAGraph replay will write + # beyond tensor boundaries causing CUDA error(700). + max_num_seqs = self.scheduler_config.max_num_seqs + max_draft_token_num = self.speculative_config.num_speculative_tokens + self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") + self.batch_id_per_token_output = paddle.full( + shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32" + ) if "token_ids_all" in self.target_model_input_batch: self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"]) # TODO: delete pre_ids in mtp @@ -946,13 +985,28 @@ def reset_model_inputs(self) -> None: self.token_ids_all = None else: self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) - self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) - self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) - self.cu_seqlens_q = paddle.clone(self.target_model_input_batch["cu_seqlens_q"]) - self.cu_seqlens_k = paddle.clone(self.target_model_input_batch["cu_seqlens_k"]) - # Reset target hidden states - fill_paddle_tensor(self, "target_hidden_states", 0) + # NOTE(fix): These tensors are dynamically resized during runtime inference. + # Must recreate at full initial size to avoid CUDAGraph replay OOB access. + max_num_seqs = self.scheduler_config.max_num_seqs + if self.enable_mm and self.model_config.mm_max_tokens_per_item is None: + self.max_chunk_tokens = self.model_config.max_model_len + else: + self.max_chunk_tokens = self.fd_config.get_max_chunk_tokens(self.model_config.mm_max_tokens_per_item) + self.ids_remove_padding = paddle.full([max_num_seqs * self.max_chunk_tokens], 0, dtype="int64") + self.batch_id_per_token = paddle.full([max_num_seqs * self.max_chunk_tokens, 1], 0, dtype="int32") + self.cu_seqlens_q = paddle.full([max_num_seqs + 1], 0, dtype="int32") + self.cu_seqlens_k = paddle.full([max_num_seqs + 1], 0, dtype="int32") + + # Reset target hidden states - must recreate at full size + self.target_hidden_states = paddle.full( + [ + self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens, + self.model_config.hidden_size, + ], + 0, + dtype="bfloat16", + ) # Reset rope embedding by recreating with default position_ids tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) @@ -1030,10 +1084,9 @@ def reset_model_inputs(self) -> None: # Reset multimodal tensors if enabled if self.enable_mm: fill_paddle_tensor(self, "decode_states", -1) - if self.fd_config.deploy_modality != DeployModality.TEXT: - fill_paddle_tensor(self, "attn_mask_offsets", -1) - fill_paddle_tensor(self, "attn_mask_offsets_full", -1) - fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1) + fill_paddle_tensor(self, "attn_mask_offsets", -1) + fill_paddle_tensor(self, "attn_mask_offsets_full", -1) + fill_paddle_tensor(self, "attn_mask_offsets_decoder", -1) logger.info("model_inputs reset completed") except Exception as e: @@ -1201,3 +1254,13 @@ def recover_batch_index_for_sampler_output(sampler_output, index_to_batch_id, en logits = sampler_output.logits real_logits = _recover_tensor(logits, src_order) sampler_output.logits = real_logits + + if sampler_output.sampling_mask is not None: + sampling_mask = sampler_output.sampling_mask + sort_len = len(src_order) + real_sampling_mask = [None] * len(sampling_mask) + for i in range(sort_len): + real_sampling_mask[i] = sampling_mask[src_order[i]] + for i in range(sort_len, len(sampling_mask)): + real_sampling_mask[i] = sampling_mask[i] + sampler_output.sampling_mask = real_sampling_mask diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index 93f5cec6a57..2673386a927 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -51,9 +51,6 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) -from fastdeploy.model_executor.layers.moe.routing_indices_cache import ( - RoutingReplayManager, -) from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata @@ -97,7 +94,7 @@ def __init__( ): super().__init__(fd_config=fd_config, device=device) self.MAX_INFER_SEED = 9223372036854775806 - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id @@ -203,8 +200,6 @@ def __init__( # Rollout routing replay config self.routing_replay_manager = None - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager = RoutingReplayManager(fd_config=self.fd_config) self.zmq_client = None self.async_output_queue = None @@ -786,11 +781,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.forward_batch_reqs_list[idx] = request has_prefill_task = True - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - if prefill_start_index == 0: - self.routing_replay_manager.register_request(batch_id=idx, request_id=request.request_id) - if ( self.fd_config.scheduler_config.splitwise_role == "decode" ): # In PD, we continue to decode after P generate first token @@ -822,10 +812,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.in_progress_prompt_logprobs.pop(request.request_id, None) self.forward_batch_reqs_list[idx] = None - # Routing Replay - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_request(batch_id=idx) - continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1239,9 +1225,9 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): Initialize forward meta, attention meta data and update some config. """ # Initialize forward meta - routing_replay_table = None + device_routing_buffer = None if self.routing_replay_manager is not None: - routing_replay_table = self.routing_replay_manager.get_routing_table() + device_routing_buffer = self.routing_replay_manager.get_device_routing_buffer() self.forward_meta = ForwardMeta( ids_remove_padding=self.share_inputs["ids_remove_padding"], rotary_embs=self.share_inputs["rope_emb"], @@ -1268,7 +1254,8 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_batch_ids=self.share_inputs["kv_batch_ids"], kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], - routing_replay_table=routing_replay_table, + routing_replay_table=None, + device_routing_buffer=device_routing_buffer, ) dist_status = self.collect_distributed_status() @@ -1345,9 +1332,10 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. # 2. If no need to profile, create kv cache if cache managers do not exist. + # Note: even when CPU cache (num_cpu_blocks > 0) is enabled, GPU runner still + # creates GPU cache tensors; cache transfer manager handles CPU<->GPU swap. create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) @@ -1464,6 +1452,8 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) self.share_inputs.update(res_buffer) @@ -1790,8 +1780,8 @@ def _dummy_run( # only need to capture prefill break - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.clear_routing_table() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() @sot_warmup_guard(True) def capture_model(self) -> None: @@ -2302,7 +2292,7 @@ def _postprocess( and self.share_inputs["is_block_step"].sum() == 0 and self.share_inputs["is_chunk_step"].sum() == 0 ): - self.routing_replay_manager.put_table_to_store() + pass # Routing store submission now handled by RoutingCacheManager on Engine side return model_output_data, sampler_output, post_process_done def _save_model_output( @@ -2491,8 +2481,7 @@ def not_need_stop(self) -> bool: def clear_cache(self, profile=False): """Clear cached data from shared inputs and forward metadata""" create_cache_tensor = profile or not ( - self.fd_config.cache_config.num_cpu_blocks > 0 - or self.fd_config.cache_config.kvcache_storage_backend + self.fd_config.cache_config.kvcache_storage_backend or self.fd_config.scheduler_config.splitwise_role != "mixed" ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size @@ -2511,7 +2500,7 @@ def clear_parameters(self, pid): """Dynamic model loader use to clear parameters use for RL""" # Clear CUDAGraph if self.use_cudagraph: - self.model.clear_grpah_opt_backend() + self.model.clear_graph_opt_backend() # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters( pid, self.fd_config.parallel_config.shutdown_comm_group_if_worker_idle @@ -2530,8 +2519,8 @@ def clear_requests(self): self.prompt_logprobs_reqs.clear() self.in_progress_prompt_logprobs.clear() self.forward_batch_reqs_list = [None for _ in range(self.scheduler_config.max_num_seqs)] - if self.fd_config.routing_replay_config.enable_routing_replay: - self.routing_replay_manager.put_table_to_store() + if self.fd_config.routing_replay_config.enable_routing_replay and self.routing_replay_manager is not None: + self.routing_replay_manager.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 365fec12475..44cc9cb9e16 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -15,8 +15,9 @@ """ from dataclasses import dataclass, field -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional +import numpy as np import paddle @@ -178,6 +179,24 @@ class SamplerOutput: token_num_per_batch: Optional[paddle.Tensor] = None cu_batch_token_offset: Optional[paddle.Tensor] = None logits: Optional[paddle.Tensor] = None + # Sparse sampling mask for top_p/top_k: + # - Non-speculative decoding: per-request mask. This is a list of length + # num_reqs, where element i is a 1-D int32 numpy array of vocab indices + # retained by top_p/top_k for request i. Replaces the previous dense + # [num_reqs, vocab_size] bool tensor. + # - Speculative decoding: flattened per-accepted-token mask. This may be + # stored as a list aligned with all accepted tokens + # (e.g. length = total_accepted_tokens) and is regrouped by accept_num + # (number of accepted tokens per request) in post-processing before + # being sent back as per-request data. + # Callers MUST NOT assume this is always shaped by num_reqs; they should + # check whether the current path is speculative or non-speculative when + # interpreting the dimension. + sampling_mask: Optional[List[np.ndarray]] = None + # logZ_K for each request: log(sum(probs in candidate set K)) + # Used for renormalizing logprobs to match the truncated sampling distribution. + # Shape: [num_reqs] + logz_per_batch: Optional[np.ndarray] = None @dataclass diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 8182e06990b..c69dd2c859f 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -138,7 +138,7 @@ def init_distributed_environment(seed: int = 20) -> Tuple[int, int]: def update_fd_config_for_mm(fd_config: FDConfig) -> None: architectures = fd_config.model_config.architectures - if fd_config.model_config.enable_mm and ErnieArchitectures.contains_ernie_arch(architectures): + if fd_config.enable_mm_runtime and ErnieArchitectures.contains_ernie_arch(architectures): fd_config.model_config.tensor_model_parallel_size = fd_config.parallel_config.tensor_parallel_size fd_config.model_config.tensor_parallel_rank = fd_config.parallel_config.tensor_parallel_rank fd_config.model_config.vision_config.dtype = fd_config.model_config.dtype @@ -174,6 +174,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) -> self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule self.cached_control_reqs = [] + if self.ranks > 1: + self.gloo_group = dist.new_group(list(range(self.ranks)), backend="gloo") + else: + self.gloo_group = None def init_control(self): engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port @@ -256,6 +260,7 @@ def init_health_status(self) -> None: suffix=self.parallel_config.local_engine_worker_queue_port, create=False, ) + self.worker.model_runner.kv_cache_status = self.kv_cache_status # init exist_task_signal workers_exist_task = np.zeros([1], dtype=np.int32) @@ -287,19 +292,6 @@ def init_health_status(self) -> None: create=False, ) - # init engine forward signal - # If engine is being forward, engine_forward_signal_data should be 1. - # If engine is out of forward, engine_forward_signal_data should be 0. - # In pd disaggregation + EP parallel, only when engine is out of forward, scheduler send next batch to worker. - # When engine is out of forward, engine_forward_signal_data must be 0, otherwise scheduler will not schedule next batch. - engine_forward_signal_data = np.zeros([1], dtype=np.int32) - self.engine_forward_signal = IPCSignal( - name="engine_forward_signal", - array=engine_forward_signal_data, - dtype=np.int32, - suffix=self.parallel_config.local_engine_worker_queue_port, - create=False, - ) # gpu_cache_lock: file-based lock for mutual exclusion between worker # and CPU transfer when accessing GPU KV cache. self.gpu_cache_lock = IPCLock( @@ -329,11 +321,25 @@ def update_weights_from_tensor(self, mmap_infos): self.experts_manager.tensor_infos = None def _broadcast_model_weights_signal(self, src: int, group) -> int: - model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32") + model_weights_signal_tensor = paddle.full( + shape=[1], fill_value=self.model_weights_signal[0], dtype="int32", device="cpu" + ) paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group) value = model_weights_signal_tensor.numpy()[0] return int(value) + def _get_exist_task_flag(self) -> bool: + if self.nnode > 1: + return self.task_queue.read_finish_flag.get() == 1 + else: + return self.exist_task_signal.value[0] == ExistTaskStatus.EXIST + + def _update_exist_task_flag(self, flag: bool) -> None: + if self.nnode > 1: + self.task_queue.read_finish_flag.set(1 if flag else 0) + else: + self.exist_task_signal.value[0] = ExistTaskStatus.EXIST if flag else ExistTaskStatus.EMPTY + def _tp_barrier_wait(self): if current_platform.is_xpu() or self.enable_overlap_schedule: self.task_queue.worker_process_tp_barrier.wait() @@ -487,39 +493,43 @@ def event_loop_normal(self) -> None: self._init_eplb_signal() tp_size = self.parallel_config.tensor_parallel_size # Currently, only support single node - self.nnode = (tp_size + self.max_chips_per_node) // self.max_chips_per_node + self.nnode = (tp_size + self.max_chips_per_node - 1) // self.max_chips_per_node max_occupied_batch_index = 0 tp_rank = self.local_rank % tp_size # TODO: Unify status variables model_weights_status (shared memory) and model_weights_signal (numpy array) to one self.model_weights_signal = np.zeros([1], dtype=np.int32) while True: + # run eplb + self._run_eplb(tp_rank) + if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: self.model_weights_signal[0] = int(self.model_weights_status.value[0]) if self.ranks > 1: - self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None) + self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=self.gloo_group) req_dicts = None self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time()) + self._tp_barrier_wait() if tp_size > 1 else None + # The first worker detects whether there are tasks in the task queue if tp_rank == 0: if self.task_queue.exist_tasks(): if envs.ENABLE_V1_KVCACHE_SCHEDULER or not ( - self.fd_config.model_config.enable_mm and self.worker.exist_prefill() + self.fd_config.enable_mm_runtime and self.worker.exist_prefill() ): - if self.nnode > 1: - self.task_queue.read_finish_flag.set(1) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EXIST + self._update_exist_task_flag(True) + else: + self._update_exist_task_flag(False) # Synchronize the signal set by tp_rank0 visiable to other workers self._tp_barrier_wait() if tp_size > 1 else None if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS: - if self.ranks > 1: - paddle.distributed.barrier() if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL: + if self.ranks > 1: + paddle.distributed.barrier() logger.info( f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]" ) @@ -561,7 +571,7 @@ def event_loop_normal(self) -> None: self.model_weights_signal[0] = self.model_weights_status.value[0] if self.ranks > 1: self.model_weights_signal[0] = self._broadcast_model_weights_signal( - src=0, group=None + src=0, group=self.gloo_group ) time.sleep(1) self.model_weights_status.value[0] = ( @@ -569,50 +579,33 @@ def event_loop_normal(self) -> None: ) # 所有 Rank 已同步唤醒,启动权重更新流程 continue - if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1: + if self._get_exist_task_flag(): logger.debug(f"Rank: {self.local_rank} Detected new requests.") - self.engine_forward_signal.value[0] = 1 + tasks, read_finish = self.task_queue.get_tasks() # Only one of all tp_size client will get read_finish == True. if read_finish: - # Reset the two signal. - if self.nnode > 1: - self.task_queue.read_finish_flag.set(0) - else: - self.exist_task_signal.value[0] = ExistTaskStatus.EMPTY - # In EP parallel(corresponing to dp attention), we need to barrier for prefill to prevent data imbalance due to inconsistent data arrival. - # Only EP + DP prefill should barrier for data arrival. - # In mixed mode and decoder in D, we should not barrier to influence decoding. - if self.parallel_config.use_ep and self.scheduler_config.splitwise_role == "prefill": - paddle.distributed.barrier(self.parallel_config.ep_group) + self._update_exist_task_flag(False) + self._tp_barrier_wait() if tp_size > 1 else None req_dicts, control_reqs = [], [] - assert ( - len(tasks) > 0 - ), f"task_queue.get_tasks() should contain at least one tuple, [([req1, ...] ,real_bsz)], but got len(tasks)={len(tasks)}" - # In EP + DP prefill, empty task ([]) is delived in worker to barrier. For empty task, just skip and continue. - # tasks[0] contains two part, ([req1, ...] ,real_bsz) - # tasks[0][0] is [req1, ...] - # if empty batch is delived, eval(tasks[0][0]) should be False ([]), - # if batch with requests is delived, eval(tasks[0][0]) should be True, then to be processed as below. - if tasks[0][0]: - for req_dict, bsz in tasks: - if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): - control_reqs.append(req_dict[0]) + for req_dict, bsz in tasks: + if len(req_dict) > 0 and isinstance(req_dict[0], ControlRequest): + control_reqs.append(req_dict[0]) + else: + max_occupied_batch_index = int(bsz) + req_dicts.extend(req_dict) + + # todo: run control request async + if len(control_reqs) > 0: + logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") + for control_req in control_reqs: + if self.parallel_config.use_ep: + self.cached_control_reqs.append(control_req) + logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") else: - max_occupied_batch_index = int(bsz) - req_dicts.extend(req_dict) - - # todo: run control request async - if len(control_reqs) > 0: - logger.info(f"Rank: {self.local_rank} received {len(control_reqs)} control request.") - for control_req in control_reqs: - if self.parallel_config.use_ep: - self.cached_control_reqs.append(control_req) - logger.info(f"Rank: {self.local_rank} cached ep control request: {control_req}") - else: - self.run_control_method(control_req) - self._tp_barrier_wait() if tp_size > 1 else None + self.run_control_method(control_req) + self._tp_barrier_wait() if tp_size > 1 else None if len(req_dicts) > 0: # Count prefill requests in current batch @@ -628,12 +621,6 @@ def event_loop_normal(self) -> None: # Process prefill inputs self.worker.preprocess_new_task(req_dicts, max_occupied_batch_index) - else: - if self.scheduler_config.splitwise_role == "prefill": - if tp_size > 1: - # Synchronize the signal for other workers - self._tp_barrier_wait() - continue # Let the ep group run control method synchronically if envs.FD_ENABLE_V1_UPDATE_WEIGHTS and self.parallel_config.use_ep: @@ -648,7 +635,6 @@ def event_loop_normal(self) -> None: and not self.worker.model_runner.not_need_stop() ): self._tp_barrier_wait() if tp_size > 1 else None - self.engine_forward_signal.value[0] = 0 time.sleep(0.001) continue @@ -673,9 +659,6 @@ def event_loop_normal(self) -> None: if not envs.ENABLE_V1_KVCACHE_SCHEDULER: self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill() logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s") - # run eplb - self._run_eplb(tp_rank) - self.engine_forward_signal.value[0] = 0 if ( not self.parallel_config.use_ep @@ -874,6 +857,12 @@ def parse_args(): default=None, help="Configuration of SpeculativeConfig.", ) + parser.add_argument( + "--enable_flashinfer_allreduce_fusion", + action="store_true", + default=False, + help="Flag to enable all reduce fusion kernel in flashinfer.", + ) parser.add_argument( "--max_num_batched_tokens", type=int, @@ -925,6 +914,11 @@ def parse_args(): action="store_true", help="enable chunked moe", ) + parser.add_argument( + "--enable_moe_scores_elementwise_fuse", + action="store_true", + help="enable fused elementwise cast in get_moe_scores", + ) parser.add_argument( "--chunked_moe_size", type=int, @@ -1129,6 +1123,16 @@ def parse_args(): help="Maximum tokens per item in mm input.", ) + parser.add_argument( + "--enable_keep_sampling_mask", + "--enable-keep-sampling-mask", + action="store_true", + help=( + "Enable output of keep_sampling_mask as sparse vocab index list per token step " + "(Non-MTP: List[int]; MTP: List[List[int]])." + ), + ) + parser.add_argument( "--num_cpu_blocks", type=int, @@ -1332,7 +1336,7 @@ def run_worker_proc() -> None: # Enable batch-invariant mode for deterministic inference. # This must happen AFTER worker creation but BEFORE model loading, - # because enable_batch_invariant_mode() calls paddle.compat.enable_torch_proxy() + # because enable_batch_invariant_mode() calls paddle.enable_compat() # which makes torch appear available via proxy. If called before worker creation, # the gpu_model_runner import chain (ernie4_5_vl_processor → paddleformers → # transformers) will fail when transformers tries to query torch metadata. diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 1446257d3ae..ea83d7d0141 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -58,7 +58,7 @@ from fastdeploy.model_executor.xpu_pre_and_post_process import ( step_xpu, xpu_post_process_normal, - xpu_post_process_specualate, + xpu_post_process_speculate, xpu_pre_process, xpu_process_output, ) @@ -97,7 +97,7 @@ def __init__( local_rank: int, ): super().__init__(fd_config=fd_config, device=device) - self.enable_mm = self.model_config.enable_mm + self.enable_mm = self.fd_config.enable_mm_runtime self.rank = rank self.local_rank = local_rank self.device_id = device_id @@ -1251,6 +1251,8 @@ def initialize_kv_cache(self, profile: bool = False) -> None: # Check if gpu runner needs to create kv cache # 1. During profiling, it creates its own kv cache. # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled. + # Note: even when CPU cache (num_cpu_blocks > 0) is enabled, GPU runner still + # creates GPU cache tensors; cache transfer manager handles CPU<->GPU swap. create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed" if not create_cache_tensor: logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}") @@ -1635,7 +1637,7 @@ class at the server level, which is too granular for ModelRunner. if self.speculative_decoding: # base model post process - xpu_post_process_specualate( + xpu_post_process_speculate( sampler_output, model_output_data, self.share_inputs, diff --git a/requirements.txt b/requirements.txt index a6a7b6619c9..9a7bb3b3613 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,7 +46,6 @@ setproctitle aistudio_sdk p2pstore py-cpuinfo -flashinfer-python-paddle -flash_mask @ https://paddle-qa.bj.bcebos.com/ernie/flash_mask-4.0.post20260128-py3-none-any.whl -arctic_inference @ https://paddle-qa.bj.bcebos.com/ernie/arctic_inference-0.1.3-cp310-cp310-linux_x86_64.whl +flashinfer-python-paddle @ https://xly-devops.bj.bcebos.com/flashinfer/flashinfer_python_paddle-0.4.1.3-py3-none-any.whl +flash_mask @ https://xly-devops.bj.bcebos.com/flashmask/flash_mask-4.0.0%2Bg4c84f74-py3-none-any.whl transformers>=4.55.1,<5.0.0 diff --git a/scripts/CheckPRTemplate.py b/scripts/CheckPRTemplate.py index c51d64b9bb6..5cb4536ab88 100644 --- a/scripts/CheckPRTemplate.py +++ b/scripts/CheckPRTemplate.py @@ -28,7 +28,6 @@ "## Modifications", "## Usage or Command", "## Accuracy Tests", - "## Checklist", ] } } @@ -65,27 +64,6 @@ def check_section_content(body, section_titles): return results -def parse_checklist(section_content): - """ - Parse a checklist section and return dict of items with checked status. - Example return: - { - 'Add at least a tag in the PR title.': False, - 'Format your code, run `pre-commit` before commit.': True, - ... - } - """ - items = {} - lines = section_content.splitlines() - for line in lines: - match = re.match(r"- \[( |x|X)\] (.+)", line) - if match: - checked = match.group(1).lower() == "x" - item_text = match.group(2).strip() - items[item_text] = checked - return items - - def check_pr_template(repo, body): """Check whether a PR description follows the expected template.""" body = remove_comments(body) @@ -108,21 +86,11 @@ def check_pr_template(repo, body): else: messages.append("❌ Missing sections: {}. Please complete them.".format(", ".join(missing))) - # Check Checklist items if present - checklist_content = results.get("## Checklist", "") - if checklist_content: - checklist_items = parse_checklist(checklist_content) - unchecked = [item for item, checked in checklist_items.items() if not checked] - if unchecked: - messages.append("❌ The following checklist items are not completed:") - for item in unchecked: - messages.append(f" - [ ] {item}") - if messages: messages.append( "\n💡 **Tips for fixing:**\n" "1. Each PR must follow the standard FastDeploy PR template.\n" - "2. Ensure every section (Motivation, Modifications, Usage, Accuracy Tests, Checklist) " + "2. Ensure every section (Motivation, Modifications, Usage, Accuracy Tests) " "is clearly filled with relevant details.\n" "3. You can refer to the official PR example: " "https://github.com/PaddlePaddle/FastDeploy/blob/develop/.github/pull_request_template.md\n" diff --git a/scripts/check_approval.sh b/scripts/check_approval.sh index db1e8a3e225..899019f5a24 100644 --- a/scripts/check_approval.sh +++ b/scripts/check_approval.sh @@ -40,7 +40,7 @@ function add_failed(){ } -HAS_CUSTOM_REGISTRER=`git diff -U0 upstream/$BRANCH | grep '^\+' | grep -zoE "PD_BUILD_(STATIC_)?OP" || true` +HAS_CUSTOM_REGISTRER=`git diff --merge-base -U0 upstream/$BRANCH | grep '^\+' | grep -zoE "PD_BUILD_(STATIC_)?OP" || true` if [ ${HAS_CUSTOM_REGISTRER} ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (qingqing01(dangqingqing), Jiang-Jia-Jun(jiangjiajun), heavengate(dengkaipeng)) approval for adding custom op.\n" echo_line2="You must have one PaddlePaddle RD (jeff41404(gaoxiang), yongqiangma(mayongqiang)) approval for adding custom op.\n" @@ -52,7 +52,7 @@ WORKER_OR_CONFIG_LIST=( "fastdeploy/model_executor/graph_optimization" ) -HAS_WORKER_OR_CONFIG_MODIFY=`git diff upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${WORKER_OR_CONFIG_LIST[@]}") || true` +HAS_WORKER_OR_CONFIG_MODIFY=`git diff --merge-base upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${WORKER_OR_CONFIG_LIST[@]}") || true` if [ "${HAS_WORKER_OR_CONFIG_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD gongshaotian(gongshaotian) approval for modifing [$(IFS=', '; echo "${WORKER_OR_CONFIG_LIST[*]}")]." check_approval "$echo_line1" 1 gongshaotian @@ -63,7 +63,7 @@ SPECULATIVE_DECODING_LIST=( "custom_ops/gpu_ops/speculate_decoding" ) -HAS_SPECULATIVE_DECODING_MODIFY=`git diff upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${SPECULATIVE_DECODING_LIST[@]}") || true` +HAS_SPECULATIVE_DECODING_MODIFY=`git diff --merge-base upstream/$BRANCH --name-only | grep -E $(printf -- "-e %s " "${SPECULATIVE_DECODING_LIST[@]}") || true` if [ "${HAS_SPECULATIVE_DECODING_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (freeliuzc(liuzichang01), Deleter-D(wangyanpeng04)) approval for modifing [$(IFS=', '; echo "${SPECULATIVE_DECODING_LIST[*]}")]." check_approval "$echo_line1" 1 freeliuzc Deleter-D @@ -71,7 +71,7 @@ fi ENV_FILE="fastdeploy/envs.py" -HAS_ENV_MODIFY=$(git diff upstream/$BRANCH --name-only | grep -E "^${ENV_FILE}$" || true) +HAS_ENV_MODIFY=$(git diff --merge-base upstream/$BRANCH --name-only | grep -E "^${ENV_FILE}$" || true) if [ "${HAS_ENV_MODIFY}" != "" ] && [ "${PR_ID}" != "" ]; then echo_line1="You must have one FastDeploy RD (Jiang-Jia-Jun(jiangjiajun), yuanlehome(liuyuanle), rainyfly(chenjian26), Wanglongzhi2001(wanglongzhi)) approval for modifying [${ENV_FILE}]." check_approval "$echo_line1" 1 Jiang-Jia-Jun yuanlehome rainyfly Wanglongzhi2001 diff --git a/scripts/coverage_run.sh b/scripts/coverage_run.sh index 7d001ea83c1..d3b641f1878 100644 --- a/scripts/coverage_run.sh +++ b/scripts/coverage_run.sh @@ -47,11 +47,13 @@ classify_tests() { } # ============================================================ -# Run Test With Logging +# Run Test With Logging (with retry for OOM/Kill) # ============================================================ run_test_with_logging() { local test_file=$1 local log_prefix=$2 + local max_retries=3 # Max retries for OOM/Kill issues + local retry_count=0 local status echo "Running pytest file: $test_file" @@ -67,14 +69,37 @@ run_test_with_logging() { # Set FD_LOG_DIR to isolate logs for each test export FD_LOG_DIR="$isolated_log_dir" - # Run test - timeout 600 python -m coverage run -m pytest -c ${PYTEST_INI} "$test_file" -vv -s - status=$? + # Retry loop for OOM/Kill issues (only handle "Killed" / SIGKILL) + while [ $retry_count -le $max_retries ]; do + if [ $retry_count -gt 0 ]; then + echo "" + echo "==================== Retrying (${retry_count}/${max_retries}) ====================" + echo "Previous attempt was Killed, retrying..." + # Clean up before retry + sleep 5 # Wait a bit to let resources be released + fi + + # Run test + timeout 600 python -m coverage run -m pytest -c ${PYTEST_INI} "$test_file" -vv -s + status=$? + + # Exit code 137 = SIGKILL (Killed / OOM) + if [ "$status" -eq 137 ] && [ $retry_count -lt $max_retries ]; then + retry_count=$((retry_count + 1)) + continue + fi + + # Break loop on success or non-Kill error or max retries reached + break + done if [ "$status" -ne 0 ]; then echo "$test_file" >> "$log_prefix" echo "" echo "==================== Test Failed: $test_file ====================" + if [ $retry_count -gt 0 ]; then + echo "Total attempts: $((retry_count + 1))" + fi # Use isolated log directory for this test if [ -d "$isolated_log_dir" ]; then @@ -94,7 +119,7 @@ run_test_with_logging() { fi echo ">>> grep error in ${isolated_log_dir}" - grep -Rni --color=auto "error" "${isolated_log_dir}" || true + grep -Rni --color=auto "error" "${isolated_log_dir}" --exclude="pytest_*_error.log" || true fi # print all server logs diff --git a/scripts/run_golang_router.sh b/scripts/run_golang_router.sh index 66578d267d9..85e204bf72c 100644 --- a/scripts/run_golang_router.sh +++ b/scripts/run_golang_router.sh @@ -54,7 +54,7 @@ for test_file in "${test_files[@]}"; do fi echo ">>> grep error in ${log_dir}" - grep -Rni --color=auto "error" "${log_dir}" || true + grep -Rni --color=auto "error" "${log_dir}" --exclude="pytest_*_error.log" || true fi done diff --git a/scripts/run_gpu_4cards.sh b/scripts/run_gpu_4cards.sh index 719ec19255c..9874302cf86 100644 --- a/scripts/run_gpu_4cards.sh +++ b/scripts/run_gpu_4cards.sh @@ -44,7 +44,7 @@ for test_file in "${test_files[@]}"; do if [ -d "${REPO_ROOT}/log" ]; then echo ">>> grep error in ${REPO_ROOT}/log/" - grep -Rni --color=auto "error" "${REPO_ROOT}/log/" || true + grep -Rni --color=auto "error" "${REPO_ROOT}/log/" --exclude="pytest_*_error.log" || true else echo "${REPO_ROOT}/log directory not found" fi diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 8eafe280346..f16eea8dd82 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -7,7 +7,18 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p python -m pip install -r requirements.txt python -m pip install jsonschema aistudio_sdk==0.3.5 -python -m pip install xgrammar==0.1.19 torch==2.6.0 +# Use prebuilt wheel files to install xgrammar==0.1.19 triton==3.4.0 nvidia_nccl_cu12==2.27.3 and torch==2.8.0 specifically for the CI environment +python -m pip install --no-deps \ + https://paddle-qa.bj.bcebos.com/FastDeploy/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl \ + https://paddle-qa.bj.bcebos.com/FastDeploy/xgrammar-0.1.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + +# install runtime dependencies for torch and xgrammar +python -m pip install pydantic sentencepiece tiktoken ninja filelock sympy jinja2 fsspec + +# fix tests/ci_use/Prefix_Caching_Swap/test_vl_prefix_caching_swap.py (requires new pyarrow memory behavior) +python -m pip install pyarrow==24.0.0 failed_files=() run_path="$DIR/../tests/ci_use/" @@ -34,7 +45,7 @@ for subdir in "$run_path"*/; do if [ $exit_code -ne 0 ]; then if [ -d "${subdir%/}/log" ]; then echo ">>> grep error in ${subdir%/}/log/" - grep -Rni --color=auto "error" "${subdir%/}/log/" || true + grep -Rni --color=auto "error" "${subdir%/}/log/" --exclude="pytest_*_error.log" || true else echo "${subdir%/}/log directory not found" fi diff --git a/setup.py b/setup.py index 55e78125f5a..4c4e24f950e 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ from pathlib import Path import paddle +from packaging import tags from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext from setuptools.command.install import install @@ -42,16 +43,17 @@ class CustomBdistWheel(bdist_wheel): - """Custom wheel builder for pure Python packages.""" + """Custom wheel builder.""" def finalize_options(self): - """Configure wheel as pure Python and platform-independent.""" + """Configure wheel as {python tag}-{abi tag}-{platform tag}.""" super().finalize_options() - self.root_is_pure = True - self.python_tag = "py3" - self.abi_tag = "none" + tag = next(tags.sys_tags()) + self.root_is_pure = False + self.python_tag = tag.interpreter + self.abi_tag = tag.abi self.plat_name_supplied = True - self.plat_name = "any" + self.plat_name = tag.platform class CMakeExtension(Extension): @@ -251,7 +253,7 @@ def get_name(): cmdclass_dict = {"bdist_wheel": CustomBdistWheel} cmdclass_dict["build_ext"] = CMakeBuild -FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.5.0-dev") +FASTDEPLOY_VERSION = os.environ.get("FASTDEPLOY_VERSION", "2.6.0") cmdclass_dict["build_optl"] = PostInstallCommand diff --git a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py index 121e74ee4b9..54ce40ca5d4 100644 --- a/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py +++ b/tests/batch_invariant/test_rmsnorm_layer_batch_invariant.py @@ -31,6 +31,7 @@ def _make_minimal_rmsnorm(hidden_size, eps=1e-5, dtype="float32"): layer.bias = None layer.split_x = False layer.allgather_out = False + layer.enable_all_reduce_fusion = False return layer diff --git a/tests/cache_manager/test_cache_messager.py b/tests/cache_manager/test_cache_messager.py index d053653e658..07ff5054f2a 100644 --- a/tests/cache_manager/test_cache_messager.py +++ b/tests/cache_manager/test_cache_messager.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import os import sys import types @@ -19,10 +21,24 @@ import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) - -from fastdeploy.cache_manager import cache_messager +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None + +# Import the legacy cache_messager module directly from the .py file, +# because the cache_messager/ package shadows it and the legacy +# fallback (cache_messager_legacy) does not exist locally. +_cm_py_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "fastdeploy", + "cache_manager", + "cache_messager.py", +) +_spec = importlib.util.spec_from_file_location( + "fastdeploy.cache_manager.cache_messager_py", + _cm_py_path, +) +cache_messager = importlib.util.module_from_spec(_spec) +_spec.loader.exec_module(cache_messager) class _DummyBarrier: @@ -40,12 +56,10 @@ def __init__(self, cache_info_sequence=None, connect_task_sequence=None, **kwarg self.cache_info_calls = 0 self.connect_task_calls = 0 self.cache_info_barrier = _DummyBarrier() - self.finish_add_cache_task_barrier = _DummyBarrier() self.finish_send_cache_barrier = _DummyBarrier() self.connect_task_barrier = _DummyBarrier() self.connect_task_response_barrier = _DummyBarrier() self.begin_send_cache_barrier = _DummyBarrier() - self.finished_add_cache_task_req_ids = [] self.finished_req_payloads = [] self.connect_task_responses = [] @@ -56,9 +70,6 @@ def get_cache_info(self): self.cache_info_calls += 1 return info - def put_finished_add_cache_task_req(self, req_ids): - self.finished_add_cache_task_req_ids.append(req_ids) - def put_finished_req(self, payload): self.finished_req_payloads.append(payload) @@ -113,6 +124,14 @@ def error(self, msg): self.messages.append(("error", msg)) +class _QueueRecorder: + def __init__(self): + self.items = [] + + def put(self, item): + self.items.append(item) + + class _DummySignalValue: def __init__(self, sequence): self.sequence = list(sequence) @@ -376,10 +395,114 @@ def test_cache_messager_v1_add_cache_task_thread(monkeypatch): } with pytest.raises(SystemExit): messager._add_cache_task_thread() - assert dummy_queue.finished_add_cache_task_req_ids == [["req-2"]] assert messager.cache_info["req-2"]["status"] == "init" +def test_cache_messager_v1_recovers_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 64) + messager.pending_layer0_signals[4] = (4, 64) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {4: (4, 64)} + assert messager.cache_prefilled_engine_ids_queue.items == [[(3, 64)]] + + +def test_cache_messager_v1_drops_invalid_pending_layer0_signal(monkeypatch): + dummy_queue = _DummyEngineWorkerQueue( + cache_info_sequence=[ + [ + { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + ] + ] + ) + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", lambda *args, **kwargs: dummy_queue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=True, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + messager.cache_info["req-pending"] = { + "request_id": "req-pending", + "src_block_ids": [0, 1], + "dest_block_ids": [2], + "current_id": 3, + "need_prefill_tokens": 128, + "transfer_protocol": "rdma", + } + messager.pending_layer0_signals[3] = (3, 256) + + with pytest.raises(SystemExit): + messager._add_cache_task_thread() + + assert messager.pending_layer0_signals == {} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_cache_messager_v1_prefill_layerwise_send_cache_thread(monkeypatch): class _OneShotQueue: def __init__(self): @@ -425,10 +548,12 @@ def get(self): } messager.engine_cache_tasks[0] = {"prefilled_layer_idx": 1, "prefilled_token_num": 64} messager.cache_info["req-3"] = messager.idx_cache_task_dict[0] + messager.pending_layer0_signals = {0: (0, 64), 1: (1, 64)} with pytest.raises(SystemExit): messager.prefill_layerwise_send_cache_thread() assert dummy_queue.finished_req_payloads assert dummy_queue.finished_req_payloads[0][0][0] == "req-3" + assert messager.pending_layer0_signals == {1: (1, 64)} def test_cache_messager_v1_handle_connect_task(monkeypatch): @@ -552,13 +677,6 @@ def test_cache_messager_v1_consume_signals(monkeypatch): monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) - class _QueueRecorder: - def __init__(self): - self.items = [] - - def put(self, item): - self.items.append(item) - counter = {"calls": 0} def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): @@ -590,12 +708,57 @@ def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): rdma_port="2222", ) messager.cache_info["req-4"] = {"request_id": "req-4"} + messager.idx_cache_task_dict[2] = {"request_id": "req-4", "current_id": 2} messager.cache_prefilled_engine_ids_queue = _QueueRecorder() with pytest.raises(SystemExit): messager.consume_signals() assert messager.cache_prefilled_engine_ids_queue.items == [[(2, 9)]] +def test_cache_messager_v1_consume_signals_buffers_early_layer0(monkeypatch): + monkeypatch.setattr(cache_messager, "EngineWorkerQueue", _DummyEngineWorkerQueue) + monkeypatch.setattr(cache_messager, "RDMACommManager", _DummyRDMACommManager) + monkeypatch.setattr(cache_messager, "logger", _DummyLogger(), raising=False) + + signals = [(5, 7, 9), (5, 17, 19)] + + def _fake_get_output_kv_signal(kv_signal_data, rank_id, wait_flag): + if not signals: + raise SystemExit + engine_idx, chuck_token_offset, current_seq_len = signals.pop(0) + data = np.full(kv_signal_data.shape, -1, dtype="int32") + data[0] = 1 + data[1] = 0 + data[2] = engine_idx + data[3] = chuck_token_offset + data[4] = current_seq_len + kv_signal_data.set_value(data) + + monkeypatch.setattr(cache_messager, "get_output_kv_signal", _fake_get_output_kv_signal) + gpu_cache_kvs = _build_cache_kvs(dtype="float16", include_value_cache=False, num_layers=1) + messager = cache_messager.CacheMessagerV1( + splitwise_role="mixed", + transfer_protocol="rdma", + pod_ip="0.0.0.0", + engine_worker_queue_port=9000, + local_data_parallel_id=0, + gpu_cache_kvs=gpu_cache_kvs, + rank=0, + nranks=1, + num_layers=1, + gpu_id=0, + block_size=64, + rdma_port="2222", + ) + messager.cache_prefilled_engine_ids_queue = _QueueRecorder() + + with pytest.raises(SystemExit): + messager.consume_signals() + + assert messager.pending_layer0_signals == {5: (5, 36)} + assert messager.cache_prefilled_engine_ids_queue.items == [] + + def test_main_initializes_cache_and_exits(monkeypatch): monkeypatch.setattr(cache_messager, "set_device", lambda device: None) monkeypatch.setattr(cache_messager, "set_data_ipc", lambda tensor, name: None) diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py index 76419eba8cd..5d2a054761e 100644 --- a/tests/cache_manager/test_cache_transfer_manager.py +++ b/tests/cache_manager/test_cache_transfer_manager.py @@ -21,15 +21,9 @@ import paddle -# Ensure paddle exposes compat.enable_torch_proxy for fastdeploy import compatibility. -if not hasattr(paddle, "compat"): - - class _DummyCompat: - @staticmethod - def enable_torch_proxy(scope=None): - return None - - paddle.compat = _DummyCompat() +# Ensure paddle exposes enable_compat for fastdeploy import compatibility. +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None # Add the root directory to Python path so we can import fastdeploy sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) @@ -65,6 +59,8 @@ class Args: kvcache_storage_backend = None write_policy = "write_through" model_path = "test_model" + splitwise_role = "mixed" + routing_replay_config = MagicMock(enable_routing_replay=False) # ========================== @@ -453,7 +449,7 @@ def test_write_back_storage_task_skips_cached_keys(self): self.manager._run_write_back_storage.assert_not_called() self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( - (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], []) + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "5", ["k1", "k2"], [0, 1]) ) def test_read_storage_task_no_matches(self): @@ -723,6 +719,7 @@ class LocalArgs(Args): * manager.cache_item_bytes, device_id=manager.device, dp_id=manager.local_data_parallel_id, + splitwise_role=LocalArgs.splitwise_role, ) def test_invalid_write_policy_raises(self): @@ -740,7 +737,7 @@ class LocalArgs(Args): def test_write_back_storage_task_nonzero_rank_no_signal(self): self.manager.cache_task_queue.swap_to_storage_barrier = MagicMock() self.manager.cache_task_queue.put_transfer_done_signal = MagicMock() - self.manager._run_write_back_storage = MagicMock() + self.manager._run_write_back_storage = MagicMock(return_value=1) self.manager.rank = 1 # Mock storage backend to return 0 matches (no keys exist) @@ -764,7 +761,10 @@ def test_write_back_storage_task_nonzero_rank_no_signal(self): [0], 0.1, ) - self.manager.cache_task_queue.put_transfer_done_signal.assert_not_called() + # After the refactor, the done signal is always sent regardless of rank. + self.manager.cache_task_queue.put_transfer_done_signal.assert_called_once_with( + (cache_transfer_manager.CacheStatus.GPU2STORAGE, "9", ["k1"], [0]) + ) def test_get_key_prefix_from_version(self): with patch("fastdeploy.cache_manager.cache_transfer_manager.yaml.safe_load") as mock_load: diff --git a/tests/cache_manager/test_prefix_cache_manager.py b/tests/cache_manager/test_prefix_cache_manager.py index 2ed9a0e6b02..8dd9b5162c4 100644 --- a/tests/cache_manager/test_prefix_cache_manager.py +++ b/tests/cache_manager/test_prefix_cache_manager.py @@ -30,8 +30,8 @@ "ignore:ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead:DeprecationWarning" ) -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **_: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **_: None warnings.filterwarnings( "ignore", @@ -1485,6 +1485,7 @@ def test_recv_data_transfer_result_handles_storage_events(self): (CacheStatus.STORAGE2GPU, "pref", ["h1"], [1, 2]), (CacheStatus.STORAGE2GPU, "pref", ["h2"], [1]), (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), + (CacheStatus.GPU2STORAGE, "write", ["h3"], [9]), ] manager.cache_task_queue = _FakeTransferQueue(payloads) with self.assertRaises(SystemExit): @@ -1544,6 +1545,46 @@ def test_reset_sets_empty_cpu_free_list_when_no_cpu_blocks(self): manager.reset() self.assertEqual(manager.cpu_free_block_list, []) + @patch("fastdeploy.cache_manager.prefix_cache_manager.envs") + def test_free_gpu_block_ids_flushes_cache_gone_with_as_only_flush(self, mock_envs): + """Verify GPU-only eviction sends flush(flush_cache_exists=False) with correct start_write_block_idx.""" + mock_envs.FD_AS_ONLY_FLUSH = True + manager = _create_manager(num_gpu_blocks=4, num_cpu_blocks=0) + manager.kvcache_storage_backend = "attention_store" + + gpu_hash = get_hash_str([9, 10]) + node = BlockNode( + 91, + [9, 10], + gpu_hash, + 3, + 0, + 2, + gpu_hash, + 0, + parent=manager.radix_tree_root, + cache_status=CacheStatus.GPU, + ) + node.shared_count = 0 + node.block_id = 12 + manager.radix_tree_root.children[gpu_hash] = node + manager.node_map[node.node_id] = node + manager.gpu_lru_leaf_heap.append(node) + manager.gpu_lru_leaf_set.add(node) + + captured_tasks = [] + manager.issue_write_back_storage_task = lambda task, is_sync=True: captured_tasks.append(task) + + manager.free_block_ids_async(1) + + self.assertEqual(len(captured_tasks), 1) + flush_task = captured_tasks[0] + self.assertFalse(flush_task.flush_cache_exists) + self.assertEqual(flush_task.keys, [gpu_hash]) + self.assertEqual(flush_task.token_ids, [9, 10]) + self.assertEqual(flush_task.gpu_block_ids, []) + self.assertEqual(flush_task.start_write_block_idx, 2) + if __name__ == "__main__": unittest.main() diff --git a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py index 6d8dfac53fd..7c3c6657434 100644 --- a/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py +++ b/tests/ci_use/EB_Lite_with_adapter/test_eblite_serving.py @@ -16,7 +16,6 @@ import queue import shutil import signal -import socket import subprocess import sys import time @@ -30,6 +29,7 @@ sys.path.insert(0, project_root) from ci_use.EB_Lite_with_adapter.zmq_client import LLMControlClient, LLMReqClient +from e2e.utils.serving_utils import clean_ports, is_port_open env = os.environ.copy() @@ -79,88 +79,6 @@ def zmq_control_client(): return client -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(): - """ - Kill all processes occupying the ports listed in PORTS_TO_CLEAN. - """ - print(f"Cleaning ports: {PORTS_TO_CLEAN}") - for port in PORTS_TO_CLEAN: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in PORTS_TO_CLEAN: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) - - @pytest.fixture(scope="session", autouse=True) def setup_and_run_server(): """ @@ -170,8 +88,15 @@ def setup_and_run_server(): - Waits for server port to open (up to 30 seconds) - Tears down server after all tests finish """ + # 清理/dev/shm中的临时文件 + try: + subprocess.run("rm -rf /dev/shm/*", shell=True) + print("Successfully cleaned up /dev/shm.") + except Exception as e: + print(f"Failed to cleanup /dev/shm: {e}") + print("Pre-test port cleanup...") - clean_ports() + clean_ports(PORTS_TO_CLEAN) base_path = os.getenv("MODEL_PATH") if base_path: @@ -236,7 +161,7 @@ def setup_and_run_server(): print("\n===== Post-test server cleanup... =====") try: os.killpg(process.pid, signal.SIGTERM) - clean_ports() + clean_ports(PORTS_TO_CLEAN) print(f"API server (pid={process.pid}) terminated") except Exception as e: print(f"Failed to terminate API server: {e}") diff --git a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py index fde03d70ee1..b42799ce066 100644 --- a/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py +++ b/tests/ci_use/Qwen2-7B-Instruct_offline/test_Qwen2-7B-Instruct_offline.py @@ -13,9 +13,7 @@ # limitations under the License. import os -import signal -import socket -import subprocess +import sys import time import traceback @@ -23,21 +21,17 @@ from fastdeploy import LLM, SamplingParams -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) -MAX_WAIT_SECONDS = 60 - +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..", "..")) +sys.path.insert(0, project_root) +from e2e.utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False +MAX_WAIT_SECONDS = 60 def format_chat_prompt(messages): @@ -74,19 +68,15 @@ def llm(model_path): """ Fixture to initialize the LLM model with a given model path """ - try: - output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip() - for pid in output.splitlines(): - os.kill(int(pid), signal.SIGKILL) - print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}") - except subprocess.CalledProcessError: - pass + # Clean ports before starting the test + clean_ports() try: start = time.time() llm = LLM( model=model_path, tensor_parallel_size=1, + port=FD_API_PORT, engine_worker_queue_port=FD_ENGINE_QUEUE_PORT, cache_queue_port=FD_CACHE_QUEUE_PORT, max_model_len=32768, @@ -94,15 +84,7 @@ def llm(model_path): logits_processors=["LogitBiasLogitsProcessor"], ) - # Wait for the port to be open - wait_start = time.time() - while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT): - if time.time() - wait_start > MAX_WAIT_SECONDS: - pytest.fail( - f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}" - ) - time.sleep(1) - + time.sleep(2) print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.") yield llm except Exception: diff --git a/tests/ci_use/metrics/test_metrics.py b/tests/ci_use/metrics/test_metrics.py index 0d5353780f0..a54504c29bd 100644 --- a/tests/ci_use/metrics/test_metrics.py +++ b/tests/ci_use/metrics/test_metrics.py @@ -214,29 +214,28 @@ def test_metrics_with_clear_and_reset(): """ Test the metrics monitoring endpoint. """ - pass # not stable, uncomment after bug fix - # metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + metrics_url = f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" - # async_concurrency(n=10) + async_concurrency(n=10) - # time.sleep(0.3) + time.sleep(0.3) # ===== clear_load_weight ===== - # clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" - # print("Calling clear_load_weight...") - # r = requests.get(clear_url, timeout=30) - # assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" - - # metrics = get_metrics_dict(metrics_url) - # running = metrics["fastdeploy:num_requests_running"] - # waiting = metrics["fastdeploy:num_requests_waiting"] - - # print( - # "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", - # running, - # "waiting:", - # waiting, - # ) + clear_url = f"http://0.0.0.0:{FD_API_PORT}/clear_load_weight" + print("Calling clear_load_weight...") + r = requests.get(clear_url, timeout=30) + assert r.status_code == 200, f"clear_load_weight failed: {r.status_code}" + + metrics = get_metrics_dict(metrics_url) + running = metrics["fastdeploy:num_requests_running"] + waiting = metrics["fastdeploy:num_requests_waiting"] + + print( + "ASSERT after the clear_load_weight operation, the value is 0 (Request interruption stopped inference, and related requests were cleared):", + running, + "waiting:", + waiting, + ) # assert running == 0 and waiting == 0, "Expected both running and waiting to be 0 after clear_load_weight" diff --git a/tests/conftest.py b/tests/conftest.py index 057dc15aebc..5a57414bc69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,23 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. +import glob +import os +import re +import time +from typing import Any, Union + import pytest +from e2e.utils.serving_utils import ( # noqa: E402 + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) def pytest_configure(config): + """ + Configure pytest: + - Register custom markers + - Ensure log directory exists + """ config.addinivalue_line("markers", "gpu: mark test as requiring GPU platform") + log_dir = os.environ.get("FD_LOG_DIR", "log") + os.makedirs(log_dir, exist_ok=True) -def pytest_collection_modifyitems(config, items): - """Skip GPU-marked tests when not on a GPU platform. - IMPORTANT: Do NOT import paddle or fastdeploy here. This function runs - during pytest collection (before fork). Importing paddle initializes the - CUDA runtime, which makes forked child processes unable to re-initialize - CUDA (OSError: CUDA error(3), initialization error). +def pytest_collection_modifyitems(config, items): + """ + Skip tests marked with 'gpu' if no GPU device is detected. + + IMPORTANT: + Do NOT import paddle or fastdeploy here. + This hook runs during test collection (before process fork). + Importing CUDA-related libraries will initialize CUDA runtime, + causing forked subprocesses to fail with: + OSError: CUDA error(3), initialization error. """ - import glob - has_gpu = len(glob.glob("/dev/nvidia[0-9]*")) > 0 if has_gpu: @@ -40,18 +61,11 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_marker) -import time -from typing import Any, Union - -from e2e.utils.serving_utils import ( # noqa: E402 - FD_API_PORT, - FD_CACHE_QUEUE_PORT, - FD_ENGINE_QUEUE_PORT, - clean_ports, -) - - class FDRunner: + """ + Wrapper for FastDeploy LLM serving process. + """ + def __init__( self, model_name_or_path: str, @@ -88,7 +102,9 @@ def generate( sampling_params, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - + """ + Run generation and return token IDs and generated texts. + """ req_outputs = self.llm.generate(prompts, sampling_params=sampling_params, **kwargs) outputs: list[tuple[list[list[int]], list[str]]] = [] for output in req_outputs: @@ -101,6 +117,9 @@ def generate_topp0( max_tokens: int, **kwargs: Any, ) -> list[tuple[list[int], str]]: + """ + Generate outputs with deterministic sampling (top_p=0, temperature=0). + """ from fastdeploy.engine.sampling_params import SamplingParams topp_params = SamplingParams(temperature=0.0, top_p=0, max_tokens=max_tokens) @@ -116,4 +135,33 @@ def __exit__(self, exc_type, exc_value, traceback): @pytest.fixture(scope="session") def fd_runner(): + """Provide FDRunner as a pytest fixture.""" return FDRunner + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """ + Capture failed test cases and save error logs to FD_LOG_DIR. + + Only logs failures during the test execution phase. + """ + outcome = yield + report = outcome.get_result() + + if report.when == "call" and report.failed: + log_dir = os.environ.get("FD_LOG_DIR", "log") + os.makedirs(log_dir, exist_ok=True) + + case_name = re.sub(r"_+", "_", re.sub(r"[^\w\-.]", "_", item.nodeid.split("::", 1)[-1])).strip("_")[:200] + + error_log_file = os.path.join(log_dir, f"pytest_{case_name}_error.log") + + with open(error_log_file, "w", encoding="utf-8") as f: + f.write(f"Case name: {item.nodeid}\n") + f.write(f"Outcome: {report.outcome}\n") + f.write(f"Duration: {report.duration:.4f}s\n") + f.write("-" * 80 + "\n") + + if report.longrepr: + f.write(str(report.longrepr)) diff --git a/tests/distributed/chunked_moe.py b/tests/distributed/chunked_moe.py index fee1582f3c7..81561f5d829 100644 --- a/tests/distributed/chunked_moe.py +++ b/tests/distributed/chunked_moe.py @@ -85,6 +85,7 @@ class SchedulerConfig: name = "default" splitwise_role = "mixed" max_num_seqs = 2 + max_num_batched_tokens = 2048 parallel_config = ParallelConfig() scheduler_config = SchedulerConfig() @@ -92,6 +93,7 @@ class SchedulerConfig: model_config = MockModelConfig() cache_config = MockCacheConfig() speculative_config = MockSpecaulativeConfig() + enable_mm_runtime = MockModelConfig.enable_mm def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): return 8192 @@ -139,7 +141,7 @@ def setup_model_runner(self): model_runner.model_config = mock_model_config model_runner.cache_config = mock_cache_config model_runner.attn_backends = [MockAttentionBackend()] - model_runner.enable_mm = True + model_runner.enable_mm = mock_fd_config.enable_mm_runtime model_runner.cudagraph_only_prefill = False model_runner.use_cudagraph = False model_runner.speculative_decoding = False diff --git a/tests/e2e/test_Qwen3VL_serving.py b/tests/e2e/test_Qwen3VL_serving.py index 3872b4050ce..bbb053b13dd 100644 --- a/tests/e2e/test_Qwen3VL_serving.py +++ b/tests/e2e/test_Qwen3VL_serving.py @@ -173,7 +173,7 @@ def test_consistency_between_runs(api_url, headers, consistent_payload): content1 = result1["choices"][0]["message"]["content"] # base result - content2 = "视频中手机支架的颜色是黑色的。" + content2 = "视频中手机支架的颜色是黑色。" # Verify that result is same as the base result assert content1.startswith(content2), content1 diff --git a/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py new file mode 100644 index 00000000000..efe702240e1 --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_wo_router_v1_rdma_tp1.py @@ -0,0 +1,455 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test splitwise deployment WITHOUT Router: +# use local_scheduler, manually construct disaggregate_info, +# send requests to both Prefill and Decode concurrently. +# ENABLE_V1_KVCACHE_SCHEDULER=1, use rdma to transfer cache. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time +import uuid + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + check_service_health, + clean, +) + +# Ports for PD disaggregation (no router port needed) +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623)) + +# Prefill uses base ports, Decode uses base+1 +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_RDMA_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_RDMA_PORT + 1, +] + + +def _build_disaggregate_info() -> dict: + """Build disaggregate_info manually, replicating Router's handle_splitwise_request logic.""" + host_ip = os.getenv("FD_HOST_IP", "127.0.0.1") + return { + "prefill_ip": host_ip, + "decode_ip": host_ip, + "prefill_connector_port": FD_CONNECTOR_PORT, + "decode_connector_port": FD_CONNECTOR_PORT + 1, + "decode_device_ids": ["1"], + "decode_rdma_ports": [FD_RDMA_PORT + 1], + "transfer_protocol": "rdma", + "decode_tp_size": 1, + } + + +def _send_pd_request(payload: dict, timeout: int = 120): + """ + Send request to both Prefill and Decode concurrently, + replicate Router's fan-out forwarding behavior. + Returns the Decode response (same as Router's return_result_url_index=-1). + """ + disaggregate_info = _build_disaggregate_info() + + # Inject disaggregate_info and request_id (same as Router) + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/chat/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + headers = {"Content-Type": "application/json"} + + # For streaming, use requests with stream=True for decode response + if payload.get("stream", False): + # Send to both concurrently (same as Router's fan-out), stream from decode + import concurrent.futures + + def _post_stream(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout, stream=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post_stream, prefill_url) + decode_future = executor.submit(_post_stream, decode_url) + # Return decode streaming response immediately + decode_resp = decode_future.result() + # Consume prefill response in background (don't block) + try: + prefill_future.result(timeout=timeout) + except Exception: + pass + return decode_resp + else: + # Non-streaming: send to both, return decode response + import concurrent.futures + + def _post(url): + return requests.post(url, headers=headers, json=payload, timeout=timeout) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + prefill_future = executor.submit(_post, prefill_url) + decode_future = executor.submit(_post, decode_url) + # Wait for both, return decode response + decode_resp = decode_future.result() + # Also check prefill didn't error (but don't block on it) + try: + prefill_future.result(timeout=5) + except Exception: + pass + return decode_resp + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts Prefill and Decode instances WITHOUT Router + - Waits for both to be healthy + - Tears down after all tests finish + """ + print("Pre-test port cleanup...") + clean(PORTS_TO_CLEAN) + + print("log dir clean") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + base_log_dir = os.getenv("FD_LOG_DIR", "log") + + # Prefill instance + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill") + + prefill_log_path = "prefill.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT), + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + # No --router flag + ] + + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_prefill, + ) + time.sleep(1) + + # Decode instance + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode") + + decode_log_path = "decode.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT + 1), + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + # No --router flag + ] + + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, + env=env_decode, + ) + + # Wait up to 300 seconds for both instances to be healthy + for _ in range(60): + prefill_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT}") + decode_healthy = check_service_health(f"http://127.0.0.1:{FD_API_PORT + 1}") + if prefill_healthy and decode_healthy: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] Servers failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError("Prefill or decode server did not start") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the Decode API endpoint URL (where final responses come from). + """ + return f"http://127.0.0.1:{FD_API_PORT + 1}/v1/chat/completions" + + +@pytest.fixture +def headers(): + return {"Content-Type": "application/json"} + + +def get_stream_chunks(response): + """Parse streaming response into chunk list.""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"Parse failed: {e}, line: {line}") + else: + print(f"Request failed, status: {response.status_code}") + print("Response:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """Test streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """Test non-streaming chat with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = _send_pd_request(payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """Test streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit( + requests.post, decode_url, json=payload, headers=headers, timeout=120, stream=True + ) + response = decode_future.result() + + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """Test non-streaming completion (non-chat) with usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + # Send to /v1/completions endpoints + disaggregate_info = _build_disaggregate_info() + payload = payload.copy() + payload["disaggregate_info"] = disaggregate_info + if "request_id" not in payload: + payload["request_id"] = f"test-pd-{uuid.uuid4()}" + + prefill_url = f"http://127.0.0.1:{FD_API_PORT}/v1/completions" + decode_url = f"http://127.0.0.1:{FD_API_PORT + 1}/v1/completions" + headers = {"Content-Type": "application/json"} + + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + executor.submit(requests.post, prefill_url, json=payload, headers=headers, timeout=120) + decode_future = executor.submit(requests.post, decode_url, json=payload, headers=headers, timeout=120) + response = decode_future.result().json() + + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" diff --git a/tests/e2e/test_ernie_21b_mtp.py b/tests/e2e/test_ernie_21b_mtp.py index dc60a213217..0ac4ec789af 100644 --- a/tests/e2e/test_ernie_21b_mtp.py +++ b/tests/e2e/test_ernie_21b_mtp.py @@ -83,6 +83,7 @@ def setup_and_run_server(): json.dumps(speculative_config), "--graph-optimization-config", '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + "--enable-keep-sampling-mask", ] # Start subprocess in new process group @@ -366,3 +367,176 @@ def test_mtp_accept_ratio(api_url): prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" + + +def _assert_sampling_mask_format(sampling_mask, max_tokens): + """验证 sampling_mask 字段格式的公共辅助函数。 + + sampling_mask 是 List[List[int]]: + - 外层列表长度 == 生成的 token 数(completion_tokens),对应 MTP 每步可接受多个 token + - 内层列表为保留位置的词汇表索引(int),非空且单调递增 + """ + assert sampling_mask is not None, "sampling_mask 不应为 None" + assert isinstance(sampling_mask, list), "sampling_mask 应为 list" + assert len(sampling_mask) > 0, "sampling_mask 不应为空" + assert len(sampling_mask) <= max_tokens, "sampling_mask 长度不应超过 max_tokens" + + for token_mask in sampling_mask: + assert isinstance(token_mask, list), f"每个 token 的 mask 应为 list,实际: {type(token_mask)}" + assert len(token_mask) > 0, "每个 token 的 mask 不应为空(至少保留采样到的 token)" + for idx in token_mask: + assert isinstance(idx, int), f"mask 中的每个元素应为 int,实际: {type(idx)}" + assert idx >= 0, f"mask 索引不应为负数,实际: {idx}" + + +def test_keep_sampling_mask_stream(api_url): + """测试流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. 每个非空 chunk 的 choices[0].sampling_mask 格式为 List[List[int]] + 2. 内层列表包含词汇表保留位置的索引,非空且单调递增 + 3. 最终 sampling_mask 总长度等于 completion_tokens + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + all_sampling_masks = [] + for chunk in chunks[:-1]: # 最后一个 chunk 是 usage-only + choice = chunk["choices"][0] + # 仅当 delta 有实际内容时才应携带 sampling_mask(首个 role chunk 内容为空,不含该字段) + has_content = bool(choice.get("delta", {}).get("content")) + mask = choice.get("sampling_mask") + if has_content: + assert mask is not None, f"有内容的 chunk 缺少 sampling_mask 字段: {choice}" + if mask is not None: + assert isinstance(mask, list), f"sampling_mask 应为 list,实际: {type(mask)}" + for token_mask in mask: + assert isinstance(token_mask, list), "每个 token mask 应为 list" + assert len(token_mask) > 0, "每个 token mask 不应为空" + for idx in token_mask: + assert isinstance(idx, int) and idx >= 0, f"mask 索引应为非负 int,实际: {idx}" + all_sampling_masks.extend(mask) + + # 最后一个 chunk 携带 usage 信息 + usage = chunks[-1].get("usage") + if usage: + completion_tokens = usage["completion_tokens"] + assert ( + len(all_sampling_masks) == completion_tokens + ), f"sampling_mask 总长度 {len(all_sampling_masks)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_non_stream(api_url): + """测试非流式响应中 keep_sampling_mask 功能(MTP 模式)。 + + 验证: + 1. choices[0].sampling_mask 格式为 List[List[int]] + 2. 长度等于 completion_tokens + 3. 内层列表包含非负递增的词汇表索引 + """ + max_tokens = 20 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 0.9, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请用一句话介绍Python语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + + response = send_request(url=api_url, payload=payload).json() + assert "choices" in response, f"响应缺少 choices 字段: {response}" + choice = response["choices"][0] + assert "sampling_mask" in choice, f"choice 缺少 sampling_mask 字段: {choice}" + + sampling_mask = choice["sampling_mask"] + completion_tokens = response["usage"]["completion_tokens"] + _assert_sampling_mask_format(sampling_mask, max_tokens) + assert ( + len(sampling_mask) == completion_tokens + ), f"sampling_mask 长度 {len(sampling_mask)} 应等于 completion_tokens {completion_tokens}" + + +def test_keep_sampling_mask_top_p_1_stream(api_url): + """测试 top_p=1.0 时流式响应的 sampling_mask(MTP 模式)。 + + top_p=1.0 表示保留全部词汇,每个 token mask 应包含所有词汇表位置。 + 验证 mask 非空且每个内层列表长度 > 1(至少保留多个候选 token)。 + """ + max_tokens = 10 + payload = { + "model": "default", + "temperature": 1.0, + "top_p": 1.0, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "1+1="}, + ], + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + assert len(chunks) > 1, "流式响应应包含至少两个 chunk" + + for chunk in chunks[:-1]: + choice = chunk["choices"][0] + mask = choice.get("sampling_mask") + if mask is not None: + for token_mask in mask: + assert len(token_mask) > 1, "top_p=1.0 时每个 token 的候选集应大于 1" + + +def test_keep_sampling_mask_consistent_with_top_p(api_url): + """对比 top_p=0.1 与 top_p=0.9 时 sampling_mask 的候选集大小(非流式,MTP 模式)。 + + top_p 越小,保留的候选 token 越少,平均 mask 长度应更短。 + """ + max_tokens = 15 + + def get_avg_mask_len(top_p): + payload = { + "model": "default", + "temperature": 1.0, + "top_p": top_p, + "seed": 42, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请列举三种编程语言。"}, + ], + "max_tokens": max_tokens, + "stream": False, + } + resp = send_request(url=api_url, payload=payload).json() + mask = resp["choices"][0].get("sampling_mask") + if not mask: + return 0 + return sum(len(m) for m in mask) / len(mask) + + avg_small = get_avg_mask_len(0.1) + avg_large = get_avg_mask_len(0.9) + assert avg_small <= avg_large, f"top_p=0.1 的平均 mask 长度 ({avg_small:.1f}) 应 <= top_p=0.9 ({avg_large:.1f})" diff --git a/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py new file mode 100644 index 00000000000..0083d70e769 --- /dev/null +++ b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + is_port_open, +) + +os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN" +os.environ["FLAGS_flash_attn_version"] = "3" +os.environ["USE_DECODE_UNIFIED_ATTENTION"] = "1" + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean() + + print("log dir clean ") + if os.path.exists("log") and os.path.isdir("log"): + shutil.rmtree("log") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle") + else: + model_path = "./ernie-4_5-21b-a3b-bf16-paddle" + mtp_model_path = os.path.join(model_path, "mtp") + speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path} + + log_path = "server.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "2", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "32768", + "--max-num-seqs", + "128", + "--quantization", + "wint4", + "--speculative-config", + json.dumps(speculative_config), + "--graph-optimization-config", + '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + ] + + # Start subprocess in new process group + # 清除log目录 + if os.path.exists("log"): + shutil.rmtree("log") + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(300): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"Server is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + print(f"server (pid={process.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. + """ + return {"Content-Type": "application/json"} + + +def send_request(url, payload, timeout=60): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + # print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + # print("Prefill Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_mtp_accept_ratio(api_url): + """测试mtp接受率""" + payload = { + "model": "default", + "messages": [ + { + "role": "user", + "content": "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、" + "公共管理等多领域。在理论层面,COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性" + "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估" + "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等,并衍生出风险共享、" + "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展," + "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理" + "的文献研究综述,300字以内", + }, + ], + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "temperature": 0, + "seed": 23, + "top_p": 0, + } + + print("fastdeploy answer is :") + + try: + # TODO: 第一次和第二次存在diff,后面正常,暂时多请求一次 + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + for idx, chunk in enumerate(chunks): + print(f"\nchunk[{idx}]:\n{json.dumps(chunk, ensure_ascii=False)}") + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics = chunks[-2]["choices"][0]["speculate_metrics"] + except Exception as e: + print(f"解析失败: {e}") + print("\nresult:\n", result) + + baseline = ( + "国外项目风险管理研究起步早、体系成熟。" + "早期聚焦保险与金融领域,后拓展至多领域。" + "理论层面,COSO《企业风险管理——整合框架》及ISO31000标准提供系统性指导," + "强调全流程管理。" + "风险识别方法多样,如故障树、事件树分析;" + "评估常用VaR模型、蒙特卡洛模拟等量化工具。" + "应对策略丰富,涵盖规避、转移等基本策略及风险共享、升级等复杂策略。" + "组织文化与管理层支持对风险管理有效性影响大。" + "近年来,科技发展促使人工智能、大数据等融入," + "推动风险管理向智能化、自动化迈进 。" + ) + + baseline_ratio = { + "accepted_tokens": 130, + "rejected_tokens": 20, + "accept_ratio": 0.42307692307692313, + "average_accept_length": 1.7333333333333334, + "accepted_tokens_per_head": [75, 55], + "accept_ratio_per_head": [0.7333333333333333], + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result_2 = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics_2 = chunks[-2]["choices"][0]["speculate_metrics"] + print("chunks:", chunks[-2]) + print("baseline", speculate_metrics) + print("speculate_metrics_2", speculate_metrics_2) + assert result_2 == baseline, f"与baseline存在diff,result_2: {result}\n baseline: {baseline}" + assert speculate_metrics_2 == baseline_ratio, ( + f"speculate_metrics存在diff," f"speculate_metrics_2: {speculate_metrics_2}\n " f"baseline: {baseline_ratio}" + ) + assert speculate_metrics_2["accept_ratio"] > 0, "accept_ratio异常" + prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] + cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] + assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" diff --git a/tests/e2e/test_ernie_21b_mtp_multistep.py b/tests/e2e/test_ernie_21b_mtp_multistep.py index 8c4e3b6bab4..9f84b495f8b 100644 --- a/tests/e2e/test_ernie_21b_mtp_multistep.py +++ b/tests/e2e/test_ernie_21b_mtp_multistep.py @@ -212,11 +212,11 @@ def test_prefix_cache_text(api_url): if os.getenv("BASELINE") == "1": baseline_manager.save("base_21b_step3", result) baseline_manager.save("base_21b_mtp_metrics_step3", speculate_metrics_2) - baseline_manager.save("base_21b_logprobs_step3", logprobs_2) + baseline_manager.save("base_21b_logprobs_step3_new", logprobs_2) baseline_result = baseline_manager.load("base_21b_step3") baseline_mtp_metrics = baseline_manager.load("base_21b_mtp_metrics_step3") - baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3") + baseline_logprobs = baseline_manager.load("base_21b_logprobs_step3_new") assert logprobs == logprobs_2, ( "logprobs 前后不一致\n" diff --git a/tests/e2e/utils/rollout_routing_replay_test_utils.py b/tests/e2e/utils/rollout_routing_replay_test_utils.py index 4186a71649a..7ff1808514c 100644 --- a/tests/e2e/utils/rollout_routing_replay_test_utils.py +++ b/tests/e2e/utils/rollout_routing_replay_test_utils.py @@ -21,8 +21,8 @@ def calculate_routing_ratio(expected_routing: paddle.Tensor, actual_routing: pad print(f"token index {i}:\n expected_routing:{expected_routing[i]}\n actual_routing: {actual_routing[i]}\n") assert ( - expected_routing_length == actual_routing_length - ), f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}." + expected_routing_length + ) == actual_routing_length, f"Routing real lengths do not match. Expected length {expected_routing_length} actual length {actual_routing_length}." total_rows, elements_per_row = expected_routing.shape mask1 = paddle.any(expected_routing != -1, axis=1) @@ -156,11 +156,9 @@ def check_routing_replay_chat_completion(openai_client, moe_layer_num: int, mode cur_save_routing_path = f"./R3_tmp/routing_replay_output_{model_name}/" model_path = os.getenv("MODEL_PATH") if model_path: - baseline_path = os.path.join( - model_path, f"R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}" - ) + baseline_path = os.path.join(model_path, f"R3_BaseLine_uint8_0530/routing_replay_output_baseline_{model_name}") else: - baseline_path = f"./R3_BaseLine_dev_uint8_0402/routing_replay_output_baseline_{model_name}" + baseline_path = f"./R3_BaseLine_uint8_0530/routing_replay_output_baseline_{model_name}" stream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_stream") nonstream_baseline_path = os.path.join(baseline_path, "r3_chat_completion_nonstream") diff --git a/tests/e2e/utils/serving_utils.py b/tests/e2e/utils/serving_utils.py index 6dd5e77c9b7..9e47ca177e7 100644 --- a/tests/e2e/utils/serving_utils.py +++ b/tests/e2e/utils/serving_utils.py @@ -98,6 +98,60 @@ def kill_process_on_port(port: int): pass +def kill_process_by_unix_socket( + socket_path: str, + force: bool = True, +): + """ + 根据 unix socket 文件路径杀掉对应进程 + cmd: ss -xlpn | grep /dev/shm/fd_task_queue_8664.sock + Args: + socket_path: 例如 /dev/shm/fd_task_queue_8664.sock + force: + True -> SIGKILL + False -> SIGTERM + Returns: + pid 或 None + """ + try: + output = subprocess.check_output( + ["ss", "-xlpn"], + text=True, + ) + for line in output.splitlines(): + if socket_path not in line: + continue + m = re.search(r"pid=(\d+)", line) + if not m: + continue + pid = int(m.group(1)) + os.kill( + pid, + signal.SIGKILL if force else signal.SIGTERM, + ) + return pid + except Exception: + pass + return None + + +def cleanup_unix_socket(socket_path: str): + if not os.path.exists(socket_path): + return + try: + pid = kill_process_by_unix_socket(socket_path) + print(f"Killed process by unix socket: {socket_path}, pid={pid}") + except Exception as e: + print(f"Failed to kill process by unix socket: {socket_path}, error={e}") + finally: + try: + if os.path.exists(socket_path): + os.remove(socket_path) + print(f"Cleaned unix socket: {socket_path}") + except Exception: + pass + + def clean_ports(ports=None): """ Kill all processes occupying the ports @@ -117,6 +171,11 @@ def clean_ports(ports=None): kill_process_on_port(port) time.sleep(1) + # Clean unix socket, fd_task_queue_*.sock, for FD_ENGINE_TASK_QUEUE_WITH_SHM = 1 + print("Cleaning unix socket") + for port in ports: + cleanup_unix_socket(f"/dev/shm/fd_task_queue_{port}.sock") + def clean(ports=None): """ diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py index 551f93babd8..9c6baeae348 100644 --- a/tests/engine/test_common_engine.py +++ b/tests/engine/test_common_engine.py @@ -28,16 +28,10 @@ import numpy as np import paddle -from e2e.utils.serving_utils import clean_ports +from e2e.utils.serving_utils import PORTS_TO_CLEAN, clean_ports -if not hasattr(paddle, "compat"): - - class _PaddleCompat: - @staticmethod - def enable_torch_proxy(scope=None): - return None - - paddle.compat = _PaddleCompat() +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None from fastdeploy.cache_manager.cache_data import CacheStatus from fastdeploy.engine.args_utils import EngineArgs @@ -518,6 +512,21 @@ def _make_cfg(self, **kwargs): engine_worker_queue_port = [engine_worker_queue_port + 21 + i for i in range(dp // nnode)] cache_queue_port = [cache_queue_port + 21 + i for i in range(dp // nnode)] + # Add ports to cleanup list + ports_to_add = [] + if isinstance(engine_worker_queue_port, list): + ports_to_add.extend(engine_worker_queue_port) + else: + ports_to_add.append(engine_worker_queue_port) + if isinstance(cache_queue_port, list): + ports_to_add.extend(cache_queue_port) + else: + ports_to_add.append(cache_queue_port) + + for port in ports_to_add: + if port not in PORTS_TO_CLEAN: + PORTS_TO_CLEAN.append(port) + if kwargs.get("num_gpu_blocks_override") is not None and "kv_cache_ratio" not in kwargs: kwargs["kv_cache_ratio"] = 1 @@ -585,6 +594,7 @@ def test_start_prefill_branch_cache_manager_and_worker_dead(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -630,6 +640,7 @@ def test_start_mixed_branch_cache_after_load_and_zmq(self): eng._process_splitwise_task = lambda: None eng._schedule_request_to_worker = lambda: None eng._schedule_request_to_worker_v1 = lambda: None + eng._prepare_request_v1 = lambda: None started_cache = {} @@ -1143,22 +1154,29 @@ def test_control_pause_and_resume_paths(self): eng = self._make_mixed_engine() eng.is_paused = False eng._pause_cond = threading.Condition() - eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False)) eng.resource_manager = Mock( - preempted_all=Mock(return_value=[Request(request_id="r1", prompt_token_ids=[1], prompt_token_ids_len=1)]), - get_real_bsz=Mock(), - wait_worker_inflight_requests_finish=Mock(), + requests={"r1": Mock(output_token_ids=[1, 2, 3])}, + waiting_abort_req_id_set=set(), + to_be_aborted_req_id_set=set(), + add_abort_req_ids=Mock(), log_status=Mock(), cache_manager=Mock(reset=Mock()), - real_bsz=1, ) eng.token_processor = Mock(clear_data=Mock()) - eng.scheduler = Mock(get_inflight_requests=Mock(return_value=[]), reset=Mock()) + mock_scheduler = Mock(reset=Mock()) + mock_scheduler.requests = {} + mock_scheduler.mutex = threading.Lock() + mock_scheduler.responses = {} + mock_scheduler.batch_responses_per_step = [] + eng.scheduler = mock_scheduler eng._send_error_response = Mock() + eng._wait_inflight_drained = Mock() with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", True): eng._control_pause(ControlRequest(request_id="ctrl1", method="pause")) self.assertTrue(eng.is_paused) + eng.resource_manager.add_abort_req_ids.assert_called_once() eng._control_resume(ControlRequest(request_id="ctrl2", method="resume")) self.assertFalse(eng.is_paused) @@ -1385,21 +1403,18 @@ def test_schedule_request_to_worker_v1_mixed_single_iteration(self): task = Request(request_id="v1_r0", prompt_token_ids=[1], prompt_token_ids_len=1) task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) + eng.scheduler = Mock(put_results=Mock()) eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) - eng.resource_manager = self._make_v1_decode_rm(eng, ([], []), with_add_request=True) + eng.resource_manager = self._make_v1_decode_rm(eng, ([task], []), with_add_request=True) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False - eng.resource_manager.add_request.assert_called_once_with(task) + eng.engine_worker_queue.put_tasks.assert_called_once() self._detach_finalizer(eng) def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): @@ -1419,7 +1434,6 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): eng.scheduler = Mock(get_requests=Mock(return_value=[task]), put_results=Mock()) eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(return_value=[]), ) eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) @@ -1431,11 +1445,13 @@ def test_schedule_request_to_worker_v1_prefill_decode_alloc_error_safe(self): try: with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", False), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", + False, + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() finally: eng.running = False @@ -1456,19 +1472,14 @@ def test_schedule_request_to_worker_v1_decode_preempted_and_errors(self): task.task_type = RequestType.PREEMPTED task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.scheduler = Mock(put_results=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_x", None), ("rid_y", "bad")])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1492,18 +1503,13 @@ def test_schedule_request_to_worker_v1_decode_prefill_task_path(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.scheduler = Mock(put_results=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [])) try: - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() finally: eng.running = False @@ -1525,25 +1531,20 @@ def test_schedule_request_to_worker_v1_error_task_none_skips_send(self): task.trace_carrier = {} task.metrics.scheduler_recv_req_time = time.time() - eng.scheduler = Mock(get_requests=Mock(return_value=[]), put_results=Mock()) - eng.engine_worker_queue = Mock( - exist_tasks=Mock(return_value=False), put_tasks=Mock(), num_tasks=Mock(return_value=0) - ) + eng.scheduler = Mock(put_results=Mock()) + eng.engine_worker_queue = Mock(exist_tasks=Mock(return_value=False), put_tasks=Mock()) eng._send_error_response = Mock() eng.resource_manager = self._make_v1_decode_rm(eng, ([task], [("rid_none", None)])) - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() eng.engine_worker_queue.put_tasks.assert_called_once() eng._send_error_response.assert_not_called() self._detach_finalizer(eng) - def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): + def test_schedule_request_to_worker_v1_no_tasks_sleeps(self): eng = self._make_mixed_engine() self._setup_v1_engine(eng) @@ -1551,17 +1552,7 @@ def test_schedule_request_to_worker_v1_threadpool_shutdown_breaks(self): eng.resource_manager = self._make_v1_decode_rm(eng, ([], [])) - class DummyExecutor: - def __init__(self, max_workers=None): - pass - - def submit(self, fn): - raise RuntimeError("cannot schedule new futures after shutdown") - - with ( - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", DummyExecutor), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), - ): + with patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None): eng._schedule_request_to_worker_v1() self._detach_finalizer(eng) @@ -1584,17 +1575,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_cache_success(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=False) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_ok"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1604,11 +1586,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.split_connector.send_splitwise_tasks.assert_called() eng.split_connector.send_cache_info_to_messager.assert_called_once() @@ -1636,17 +1619,8 @@ def test_schedule_request_to_worker_v1_prefill_continuous_wait_async_none(self): eng.resource_manager = self._make_v1_prefill_continuous_rm(eng, waiting_async_result=None) - calls = {"n": 0} - - def get_finished_add_cache_task_req(): - if calls["n"] == 0: - calls["n"] += 1 - return ["pc_fail"] - return [] - eng.engine_worker_queue = Mock( exist_tasks=Mock(return_value=False), - get_finished_add_cache_task_req=Mock(side_effect=get_finished_add_cache_task_req), ) eng.split_connector = Mock( @@ -1656,11 +1630,12 @@ def get_finished_add_cache_task_req(): ) with ( - patch("fastdeploy.engine.common_engine.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True), - patch("fastdeploy.engine.common_engine.ThreadPoolExecutor", self._make_dummy_executor(eng)), - patch("fastdeploy.engine.common_engine.time.sleep", lambda *_: None), + patch( + "fastdeploy.engine.common_engine_prepare_mixin.envs.PREFILL_CONTINUOUS_REQUEST_DECODE_RESOURCES", True + ), + patch("fastdeploy.engine.common_engine_prepare_mixin.time.sleep", lambda *_: None), ): - eng._schedule_request_to_worker_v1() + eng._fetch_request_prefill() eng.scheduler.put_results.assert_called_once() eng.resource_manager.pre_recycle_resource.assert_called_once_with("pc_fail") @@ -3536,7 +3511,7 @@ def _fake_sleep(s): self.assertGreaterEqual(call_count[0], 1) self._detach_finalizer(eng) - # ── _control_abort_requests / _wait_abort_complete ─────────────── + # ── _resolve_abort_targets / _build_abort_results ─────────────── def _make_abort_engine(self, splitwise_role="mixed"): """Create an engine wired up for abort tests.""" @@ -3577,175 +3552,21 @@ def _make_fake_request(self, output_token_ids=None): req.metrics.engine_recv_first_token_time = 1000.2 return req - def test_control_abort_requests_not_v1_raises(self): - """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off.""" - eng = self._make_abort_engine() - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0): - with self.assertRaises(Exception) as ctx: - eng._control_abort_requests(control_req) - self.assertIn("only supported", str(ctx.exception)) - self._detach_finalizer(eng) - - def test_control_abort_requests_abort_all(self): - """abort_all=True aborts all requests in resource_manager + scheduler.""" + def test_resolve_abort_targets_abort_all(self): + """abort_all=True returns all requests in resource_manager + scheduler.""" eng = self._make_abort_engine() eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])} eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))} - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - # Simulate immediate abort completion - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 2) - self.assertEqual(result["not_found"], []) - ids = {a["request_id"] for a in result["aborted"]} - self.assertEqual(ids, {"req-1_0", "req-2_0"}) - # put_results should have been called (not prefill) - eng.scheduler.put_results.assert_called_once() - self._detach_finalizer(eng) - - def test_control_abort_requests_by_req_ids_with_suffix_match(self): - """req_ids match both exact and _0 suffix.""" - eng = self._make_abort_engine() - eng.resource_manager.requests = { - "req-A_0": self._make_fake_request([1, 2, 3]), - "req-B": self._make_fake_request([4, 5]), - } - - control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["req-A", "req-B", "req-C"], - }, - ) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - aborted_ids = {a["request_id"] for a in result["aborted"]} - self.assertIn("req-A_0", aborted_ids) # matched via _0 suffix - self.assertIn("req-B", aborted_ids) # exact match - self.assertEqual(result["not_found"], ["req-C"]) - self._detach_finalizer(eng) - - def test_control_abort_requests_no_match(self): - """No requests found returns empty aborted and all in not_found.""" - eng = self._make_abort_engine() - control_req = ControlRequest( - "ctrl-1", - "abort_requests", - { - "abort_all": False, - "req_ids": ["nonexistent"], - }, - ) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"], []) - self.assertEqual(result["not_found"], ["nonexistent"]) - self._detach_finalizer(eng) - - def test_control_abort_requests_prefill_skips_wait_and_put(self): - """Prefill role skips _wait_abort_complete and put_results.""" - eng = self._make_abort_engine(splitwise_role="prefill") - eng.resource_manager.requests = {"req-1_0": self._make_fake_request()} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - eng.resource_manager.add_abort_req_ids = MagicMock() - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(len(result["aborted"]), 1) - eng.scheduler.put_results.assert_not_called() - self._detach_finalizer(eng) - - def test_control_abort_requests_output_token_count(self): - """output_token_count reflects partial_token_ids length.""" - eng = self._make_abort_engine() - eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])} - - control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []}) - - def clear_abort_sets(req_id): - eng.resource_manager.waiting_abort_req_id_set.discard(req_id) - - eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets) - - with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1): - result = eng._control_abort_requests(control_req) - - self.assertEqual(result["aborted"][0]["output_token_count"], 5) - self._detach_finalizer(eng) - - def test_wait_abort_complete_immediate(self): - """_wait_abort_complete returns immediately when all requests already cleaned.""" - eng = self._make_abort_engine() - # Empty abort sets → remaining is empty → returns immediately - eng._wait_abort_complete(["req-1_0"]) - self._detach_finalizer(eng) - - def test_wait_abort_complete_progress(self): - """_wait_abort_complete exits when background thread cleans up.""" - eng = self._make_abort_engine() - eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"} - - call_count = [0] - - def fake_sleep(s): - call_count[0] += 1 - # Simulate background thread cleaning up after first sleep - eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0") - - with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep): - eng._wait_abort_complete(["req-1_0"]) - - self.assertGreaterEqual(call_count[0], 1) + target = eng._resolve_abort_targets(abort_all=True, req_ids=[]) + self.assertEqual(set(target), {"req-1_0", "req-2_0"}) self._detach_finalizer(eng) - def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self): - """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set.""" + def test_resolve_abort_targets_no_match(self): + """No matching request ids returns empty list.""" eng = self._make_abort_engine() - eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"} - - def mock_recycle(req_id): - eng.resource_manager.to_be_aborted_req_id_set.discard(req_id) - - eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle) - - # Make time.time() advance past stall_timeout - time_values = [100.0, 100.0, 102.0, 102.0, 102.0] - time_idx = [0] - - def fake_time(): - idx = min(time_idx[0], len(time_values) - 1) - time_idx[0] += 1 - return time_values[idx] - - with ( - patch("fastdeploy.engine.common_engine.time.time", fake_time), - patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None), - ): - eng._wait_abort_complete(["req-1_0"], stall_timeout=1) - - eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0") + target = eng._resolve_abort_targets(abort_all=False, req_ids=["nonexistent"]) + self.assertEqual(target, []) self._detach_finalizer(eng) diff --git a/tests/engine/test_engine.py b/tests/engine/test_engine.py index 762db4ea4ed..17de3b32bc2 100644 --- a/tests/engine/test_engine.py +++ b/tests/engine/test_engine.py @@ -68,6 +68,7 @@ def test_stop_profile_returns_true_on_success(self): parallel_config=types.SimpleNamespace(device_ids="0"), scheduler_config=types.SimpleNamespace(splitwise_role="decode"), cache_config=Mock(enable_prefix_caching=False, reset=Mock()), + routing_replay_config=types.SimpleNamespace(enable_routing_replay=False), ) eng.engine = types.SimpleNamespace( start_cache_service=lambda *_: None, diff --git a/tests/engine/test_request.py b/tests/engine/test_request.py index 9a1f0bc31cf..fd9eab17dc7 100644 --- a/tests/engine/test_request.py +++ b/tests/engine/test_request.py @@ -398,6 +398,7 @@ def test_to_dict_basic(self): request.prompt_token_ids_len = 3 request.sampling_params = SamplingParams() request.metrics = RequestMetrics() + request.metrics.prompt_token_ids_len = 3 data = request.to_dict() diff --git a/tests/engine/test_resource_manager_v1.py b/tests/engine/test_resource_manager_v1.py index 23275f29f70..716770294a6 100644 --- a/tests/engine/test_resource_manager_v1.py +++ b/tests/engine/test_resource_manager_v1.py @@ -72,7 +72,7 @@ def test_preempted_all_with_normal_requests(self): req1 = Mock(spec=Request) req1.request_id = "req1" req1.use_extend_tables = False - req1.status = RequestStatus.RUNNING + req1.status = RequestStatus.RUNNING_DECODE req1.block_tables = [1, 2, 3] req1.num_cached_blocks = 0 req1.idx = 0 @@ -80,7 +80,7 @@ def test_preempted_all_with_normal_requests(self): req2 = Mock(spec=Request) req2.request_id = "req2" req2.use_extend_tables = False - req2.status = RequestStatus.RUNNING + req2.status = RequestStatus.RUNNING_DECODE req2.block_tables = [4, 5] req2.num_cached_blocks = 0 req2.idx = 1 diff --git a/tests/engine/test_scheduler_metrics_logger.py b/tests/engine/test_scheduler_metrics_logger.py index c1305a3daa6..cab38350c49 100644 --- a/tests/engine/test_scheduler_metrics_logger.py +++ b/tests/engine/test_scheduler_metrics_logger.py @@ -32,7 +32,7 @@ def test_on_decode_tokens_accumulates(): def test_log_prefill_batch_logs_expected_message(): - logger = SchedulerMetricsLogger(enabled=True, dp_rank=2) + logger = SchedulerMetricsLogger(enabled=True, dp_rank=2, splitwise_role="prefill") logger._logger = mock.Mock() reqs = [ @@ -46,6 +46,7 @@ def test_log_prefill_batch_logs_expected_message(): message = logger._logger.info.call_args[0][0] assert "Prefill batch" in message assert "dp_rank: 2" in message + assert "splitwise_role: prefill" in message assert "#new-seq: 2" in message assert "#new-token: 4" in message assert "#cached-token: 3" in message @@ -54,8 +55,31 @@ def test_log_prefill_batch_logs_expected_message(): assert "#queue-req: 6" in message +def test_log_decode_bootstrap_batch_logs_expected_message(): + logger = SchedulerMetricsLogger(enabled=True, dp_rank=0, splitwise_role="decode") + logger._logger = mock.Mock() + + reqs = [types.SimpleNamespace(prefill_start_index=4, prefill_end_index=5, num_cached_tokens=4)] + + logger.log_decode_bootstrap_batch( + prefill_reqs=reqs, + running_cnt=1, + queue_cnt=0, + tokens_used=5, + token_usage=0.25, + ) + + logger._logger.info.assert_called_once() + message = logger._logger.info.call_args[0][0] + assert "Decode bootstrap batch" in message + assert "splitwise_role: decode" in message + assert "#new-seq: 1" in message + assert "#new-token: 1" in message + assert "#cached-token: 4" in message + + def test_log_decode_batch_computes_throughput(monkeypatch): - logger = SchedulerMetricsLogger(enabled=True, dp_rank=1) + logger = SchedulerMetricsLogger(enabled=True, dp_rank=1, splitwise_role="decode") logger._logger = mock.Mock() logger._decode_batch_count = logger._decode_log_interval - 1 logger._decode_tokens_since_last = 10 @@ -69,6 +93,7 @@ def test_log_decode_batch_computes_throughput(monkeypatch): message = logger._logger.info.call_args[0][0] assert "Decode batch" in message assert "dp_rank: 1" in message + assert "splitwise_role: decode" in message assert "gen throughput (token/s): 5.00" in message assert "#queue-req: 7" in message assert logger._decode_tokens_since_last == 0 @@ -99,3 +124,8 @@ def test_decode_log_interval_non_positive_falls_back_to_default(monkeypatch): monkeypatch.setenv("FD_CONSOLE_DECODE_LOG_INTERVAL", "0") logger = SchedulerMetricsLogger(enabled=True, dp_rank=0) assert logger._decode_log_interval == SchedulerMetricsLogger.DEFAULT_DECODE_LOG_INTERVAL + + +def test_default_splitwise_role_is_mixed(): + logger = SchedulerMetricsLogger(enabled=True, dp_rank=0) + assert logger.splitwise_role == "mixed" diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py index 48704e026b6..301e77489c1 100644 --- a/tests/entrypoints/openai/test_api_server.py +++ b/tests/entrypoints/openai/test_api_server.py @@ -828,44 +828,30 @@ def _mock_abort_control_response(api_server, result, status_code=200): async def test_abort_requests_with_req_ids(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - "aborted": [{"request_id": "req-1_0", "output_token_count": 10}], - "not_found": ["req-999"], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.method == "abort_requests" - assert control_req.args["req_ids"] == ["req-1", "req-999"] - assert control_req.args["abort_all"] is False + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["req_ids"] == ["req-1", "req-999"] + assert call_kwargs["abort_all"] is False @pytest.mark.asyncio async def test_abort_requests_with_abort_all(): args = _build_args() api_server = _reload_api_server(args) - _mock_abort_control_response( - api_server, - { - "aborted": [ - {"request_id": "req-1_0", "output_token_count": 5}, - {"request_id": "req-2_0", "output_token_count": 12}, - ], - "not_found": [], - }, - ) + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.abort_reqs = AsyncMock(return_value=None) req = MagicMock() req.json = AsyncMock(return_value={"abort_all": True}) resp = await api_server.abort_requests(req) assert resp.status_code == 200 - control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0] - assert control_req.args["abort_all"] is True - assert control_req.args["req_ids"] == [] + call_kwargs = api_server.app.state.engine_client.abort_reqs.await_args.kwargs + assert call_kwargs["abort_all"] is True + assert call_kwargs["req_ids"] == [] @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index d98e79b74f2..bd7b6482b09 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -577,6 +577,7 @@ async def test_create_chat_completion_choice(self): response_processor=mock_response_processor, max_tokens=max_tokens_list[idx], speculate_metrics=None, + sampling_mask_list=None, ) expected = case["expected"] diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index 50410ccf236..db871cc7a73 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -19,7 +19,6 @@ import os import shutil import signal -import socket import subprocess import sys import tempfile @@ -63,124 +62,16 @@ write_local_file, ) -# Read ports from environment variables; use default values if not set -FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) -FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) - -# List of ports to clean before and after tests -PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] - - -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def _clean_cuda_process(): - """ - Kill processes that are using CUDA devices. - NOTE: Do not call this function directly, use the `clean` function instead. - """ - try: - subprocess.run("fuser -k /dev/nvidia*", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(ports=None): - """ - Kill all processes occupying the ports - """ - if ports is None: - ports = PORTS_TO_CLEAN - - print(f"Cleaning ports: {ports}") - for port in ports: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in ports: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) - - -def clean(ports=None): - """ - Clean up resources used during testing. - """ - clean_ports(ports) - - # Clean CUDA devices before and after tests. - # NOTE: It is dangerous to use this flag on development machines, as it may kill other processes - clean_cuda = int(os.getenv("CLEAN_CUDA", "0")) == 1 - if clean_cuda: - _clean_cuda_process() +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) +from e2e.utils.serving_utils import ( + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + clean_ports, +) INPUT_BATCH = """ {"custom_id": "req-00001", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Can you write a short poem? (id=1)"}], "temperature": 0.7, "max_tokens": 200}} diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 1b33405503f..12f20f39eab 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -398,6 +398,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice.finish_reason, "recover_stop") @@ -421,6 +422,7 @@ async def test_create_chat_completion_choice_audio_recover(self): response_processor=response_processor, max_tokens=2, speculate_metrics=None, + sampling_mask_list=None, ) self.assertEqual(choice_length.finish_reason, "length") diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py index 01a68c2380d..0dbda0c35eb 100644 --- a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -60,6 +60,50 @@ def get_vocab(self): return ErnieX1ToolParser(tokenizer=DummyTokenizer()) + def _simulate_streaming(self, parser, deltas): + """Simulate a multi-step streaming flow. + + Args: + parser: ErnieX1ToolParser instance + deltas: list of delta text strings, each representing one streaming step + + Returns: + list of results from each extract_tool_calls_streaming call + """ + results = [] + previous_text = "" + token_id = 0 + previous_token_ids = [] + + for delta in deltas: + current_text = previous_text + delta + # When delta contains plus more content, use 2 tokens + # so that the parser extracts tool_call_portion (line 163-164) + if "" in delta and delta != "": + n_tokens = 2 + else: + n_tokens = 1 + + delta_token_ids = list(range(token_id + 1, token_id + 1 + n_tokens)) + token_id += n_tokens + current_token_ids = previous_token_ids + delta_token_ids + + result = parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta, + previous_token_ids, + current_token_ids, + delta_token_ids, + self.dummy_request, + ) + results.append(result) + + previous_text = current_text + previous_token_ids = list(current_token_ids) + + return results + # ==================== __init__ tests (lines 60-81) ==================== def test_init_sets_tokens_and_ids(self): @@ -116,6 +160,14 @@ def test_extract_tool_calls_no_arguments(self): self.assertTrue(result.tools_called) self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_empty_arguments(self): + """Cover: tool call with explicit empty arguments {}""" + output = '{"name": "fn", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "fn") + self.assertEqual(result.tool_calls[0].function.arguments, "{}") + def test_extract_tool_calls_nested_arguments(self): """Cover regex with nested braces in arguments""" output = '{"name": "query", "arguments": {"filter": {"age": {"$gt": 18}}}}' @@ -182,38 +234,24 @@ def test_streaming_balanced_counts_text_after_tool(self): def test_streaming_end_token_in_delta(self): """Cover lines 149-156: appears in delta""" parser = self._new_parser() - # First, start a tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Now stream arguments - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - ', "arguments": {"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Close with end token in delta - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 2], - [2], - self.dummy_request, - ) - # Should handle end token - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "', # start + name + args key + "v", # args value + '"}}', # close with end token in delta + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts '{"k": "v' as arguments_delta + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"k": "v') + # Step 3: end token in delta triggers close handling + # delta before is '"}}', close branch: rindex('}')=2, diff='"}' + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, '"}') # --- Lines 160-172: new tool call start (cur_start > cur_end and cur_start > prev_start) --- @@ -255,37 +293,29 @@ def test_streaming_new_tool_call_multi_tokens(self): def test_streaming_continue_tool_call_no_name_yet(self): """Cover lines 174-176, 220-222: partial JSON without name yet""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Continue with partial content, no name parseable yet - result = parser.extract_tool_calls_streaming( - "", - '{"na', - '{"na', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"na', # partial content, no name yet + ], ) - self.assertIsNone(result) + self.assertIsNone(results[0]) + self.assertIsNone(results[1]) def test_streaming_continue_tool_call_with_name(self): """Cover lines 174-176, 223-235: name becomes available""" parser = self._new_parser() - # Start tool call - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Name appears - result = parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [1], - [1, 10], - [10], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.name, "get_weather") + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"name": "get_weather"', # name appears + ], + ) + self.assertIsNone(results[0]) + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.name, "get_weather") self.assertTrue(parser.current_tool_name_sent) # --- Lines 236-237: name not sent and function_name is None --- @@ -293,18 +323,14 @@ def test_streaming_continue_tool_call_with_name(self): def test_streaming_no_function_name(self): """Cover lines 236-237: parsed JSON has no 'name' field""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Send JSON without name field - result = parser.extract_tool_calls_streaming( - "", - '{"arguments": {"k": "v"}}', - '{"arguments": {"k": "v"}}', - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + '{"arguments": {"k": "v"}}', # JSON without name field + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) # --- Lines 178-200: closing branch (cur_start == cur_end, end >= prev_end) --- @@ -333,9 +359,9 @@ def test_streaming_close_with_remaining_diff(self): parser.streamed_args_for_tool = [""] parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] result = parser.extract_tool_calls_streaming( - '{"name":"fn","arguments":{"k":"v"}}', + '{"name":"fn","arguments":{"k":"v"', '{"name":"fn","arguments":{"k":"v"}}', - '"}}', + "}}", [1, 10], [1, 10, 2], [2], @@ -343,9 +369,14 @@ def test_streaming_close_with_remaining_diff(self): ) self.assertIsNotNone(result) self.assertIsNotNone(result.tool_calls) + self.assertEqual(result.tool_calls[0].function.arguments, "}") - def test_streaming_close_with_diff_no_end_marker(self): - """Cover lines 184-185: close with arguments but no '"}' in delta_text""" + def test_streaming_text_after_completed_tool_call(self): + """Cover lines 143-147: text content after a completed tool call. + + When start==end counts, prev_end==cur_end, and end_token not in delta, + the parser treats delta as regular text content. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True @@ -353,7 +384,7 @@ def test_streaming_close_with_diff_no_end_marker(self): parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] # Simulate end token in delta but without '"}' pattern # We need cur_start==cur_end and cur_end >= prev_end, and end_token NOT in delta - # so that we enter the elif at 178 + # so that we enter the text-content branch at line 143-147 result = parser.extract_tool_calls_streaming( '{"name":"fn","arguments":{"k":"v"}}', '{"name":"fn","arguments":{"k":"v"}} text', @@ -363,8 +394,9 @@ def test_streaming_close_with_diff_no_end_marker(self): [30], self.dummy_request, ) - # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 147) - self.assertIsInstance(result, DeltaMessage) + # balanced counts, prev_end==cur_end, end not in delta -> returns content (line 149) + self.assertIsNotNone(result) + self.assertEqual(result.content, " text") def test_streaming_close_no_arguments(self): """Cover lines 182-183: close branch where prev arguments is None/empty""" @@ -382,8 +414,126 @@ def test_streaming_close_no_arguments(self): [2], self.dummy_request, ) - # diff is None (no arguments), so falls through to partial_json_parser - self.assertTrue(result is None or isinstance(result, DeltaMessage)) + # diff is None (no arguments key in prev), falls through to partial_json_parser + # parses complete JSON, cur_args=None, prev_args=None -> no-args -> delta=None + self.assertIsNone(result) + + def test_streaming_close_with_empty_dict_arguments(self): + """Regression: close branch must handle arguments={} (empty dict). + + Before fix, `if diff:` was False for empty dict {}, so the close + logic was skipped. After fix, `if diff is not None:` correctly + enters the branch. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + "", # end token + ], + ) + # Step 1: name sent + # Step 2: first-args, cur_args={} is not None, prev_args=None + # Without fix: not {} == True -> no-args branch -> returns None + # With fix: enters first-args -> streams "{}" -> DeltaMessage + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_arguments_with_outer_brace_in_same_token(self): + """Regression: when arguments={} and outer } arrive in the same token '{}}', + regex (.*) over-captures the outer brace, producing '{}}'. + + Real production data showed arguments='{}}}' for get_default_weather + with empty arguments. This test reproduces that exact scenario. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "get_default_weather", "arguments": ', # start + name + args key + "{}}", # empty args + outer close brace in same token + "", # end token + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "get_default_weather") + # Step 2: first-args branch, tool_call_portion is complete JSON + # regex (.*) captures '{}}' but fix strips outer '}' -> '{}' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 3: end token, close branch + # diff = prev_arguments = {} (not None), delta_text = '' (empty after split) + # '}' not in '' -> returns None + self.assertIsNone(results[2]) + + def test_streaming_close_with_number_ending_arguments(self): + """Regression: close branch must flush remaining args ending with number. + + Before fix, '"}' not in delta was True for numbers, causing return None. + After fix, rindex('}') correctly finds the closing brace. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"count": ', # start + name + args key + "123", # number value + "}}", # close braces + end token + ], + ) + # Step 1: name sent + # Step 2: first-args, streams {"count": 123 + # Step 3: close branch flushes remaining "}" + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"count": 123}') + + def test_streaming_close_with_boolean_ending_arguments(self): + """Regression: close branch must flush remaining args ending with boolean.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"flag": ', # start + args key + "true", # boolean value + "}}", # close + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"flag": true}') + + def test_streaming_close_with_nested_object_ending(self): + """Regression: close branch must flush remaining args ending with nested '}'.""" + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"nested": {"a": ', # start + args key + "1", # nested value + "}}}", # close all + end token + ], + ) + streamed_args = [ + r.tool_calls[0].function.arguments + for r in results + if r is not None and r.tool_calls and r.tool_calls[0].function.arguments is not None + ] + combined = "".join(streamed_args) + self.assertEqual(combined, '{"nested": {"a": 1}}') # --- Lines 202-206: else branch (cur_start < cur_end, edge case) --- @@ -404,23 +554,21 @@ def test_streaming_else_branch(self): def test_streaming_malformed_json(self): """Cover lines 213-215: MalformedJSON from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) - # Feed badly formed content - result = parser.extract_tool_calls_streaming( - "", - "{{{", - "{{{", - [1], - [1, 10], - [10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + "", # start tool call + "{{{", # badly formed content + ], ) - self.assertIsNone(result) + self.assertIsNone(results[1]) def test_streaming_json_decode_error(self): """Cover lines 216-218: JSONDecodeError from partial parser""" parser = self._new_parser() - parser.extract_tool_calls_streaming("", "", "", [], [1], [1], self.dummy_request) + # Step 1: start tool call normally + self._simulate_streaming(parser, [""]) + # Step 2: mock partial_json_parser to throw ValueError with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", side_effect=ValueError("bad json"), @@ -430,8 +578,8 @@ def test_streaming_json_decode_error(self): "bad", "bad", [1], - [1, 10], - [10], + [1, 2], + [2], self.dummy_request, ) self.assertIsNone(result) @@ -469,30 +617,17 @@ def test_streaming_tool_portion_none_with_text(self): def test_streaming_first_arguments_with_regex_match(self): """Cover lines 243-244, 257-286: first arguments appear, regex matches""" parser = self._new_parser() - # Start tool call and send name - parser.extract_tool_calls_streaming( - "", - '{"name": "get_weather"', - '{"name": "get_weather"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Now stream arguments (first time) - # Key must be complete (closing quote) so partial_json_parser returns truthy arguments. - # delta must be a substring of the regex-extracted arguments portion (after "arguments":). - result = parser.extract_tool_calls_streaming( - '{"name": "get_weather"', - '{"name": "get_weather", "arguments": {"location": "bei', - '"bei', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertIsNotNone(result.tool_calls) + results = self._simulate_streaming( + parser, + [ + '{"name": "get_weather", "arguments": {"location": "', # start + name + args key + "bei", # args value + ], + ) + # Step 1: name sent + # Step 2: first-args, regex finds "bei" in '{"location": "bei' + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, '{"location": "bei') def test_streaming_first_arguments_no_regex_match(self): """Cover lines 266-267: regex doesn't match, fallback to json.dumps""" @@ -522,67 +657,119 @@ def test_streaming_first_arguments_no_regex_match(self): self.assertIsNotNone(result.tool_calls) def test_streaming_first_arguments_delta_not_in_json(self): - """Cover lines 271-272: delta_text not found in cur_arguments_json""" + """Cover lines 275-276: delta_text not found in cur_arguments_json, returns None. + When delta contains the arguments key itself (e.g. ', "arguments": {'), + regex extracts cur_arguments_json='{' but delta ', "arguments": {' is not in '{'. + """ parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # Delta text that doesn't appear in the arguments JSON - result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}}', - "ZZZZZ", - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - self.assertIsNone(result) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + partial name + ', "arguments": {', # delta introduces arguments key + open brace + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args branch, regex extracts cur_arguments_json='{' + # delta_text=', "arguments": {' is NOT in '{' -> returns None + self.assertIsNone(results[1]) # --- Lines 249-251: no cur_arguments and no prev_arguments --- def test_streaming_no_arguments_at_all(self): """Cover lines 249-251: both cur and prev arguments are empty/None""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn"', # start + name + "}", # close JSON, no arguments + ], ) - # Continue with name only, no arguments + # prev_arguments=None, cur_arguments=None -> delta=None + self.assertIsNone(results[1]) + + def test_streaming_empty_dict_arguments_not_skipped(self): + """Regression: arguments={} (empty dict) must not be treated as no arguments. + + Empty dict is falsy in Python (`not {} == True`). Before the fix, + this caused empty arguments to enter the no-arguments branch, + silently dropping them during streaming. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # start + name + args key + "{}", # empty dict value + "}", # outer close brace + ], + ) + # Step 1: name sent + # Step 2: cur_arguments={} (not None), prev_arguments=None + # With fix: enters first-arguments branch -> streams "{}" + # Without fix: not {} == True -> no-arguments branch -> delta=None + self.assertIsNotNone(results[1]) + self.assertIsNotNone(results[1].tool_calls) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + + def test_streaming_empty_dict_prev_arguments_not_reset(self): + """Regression: prev_arguments={} must not be treated as no arguments. + + When prev has {} and cur has a non-empty dict, the code should enter + the both-have-arguments branch, not the first-arguments branch. + + This scenario (arguments growing from {} to non-empty) is hard to + produce naturally, so we build up state through a real flow then + verify the branch behavior with one additional call. + """ + parser = self._new_parser() + # Build up state naturally: prev_tool_call_arr gets arguments={} + self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # name + args key + "{}", # empty dict value + "}", # outer close + ], + ) + # Verify state is correct + self.assertEqual(parser.prev_tool_call_arr[0].get("arguments"), {}) + + # Now test: if more argument data arrives, prev_args={} should be + # treated as "not None" -> enters both-have-arguments branch + # Without fix: not {} == True -> first-arguments branch (wrong) result = parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn"}', - "}", - [1, 10], - [1, 10, 20], - [20], + '{"name": "fn", "arguments": {"k": "v', + '{"name": "fn", "arguments": {"k": "val', + "al", + [1, 2, 3], + [1, 2, 3, 4], + [4], self.dummy_request, ) - # prev_arguments=None, cur_arguments=None -> delta=None - # then prev_tool_call_arr updated and returns delta (which is None) - self.assertIsNone(result) + # both-have-arguments branch: delta_text="al" streamed as arguments + self.assertIsNotNone(result) + self.assertEqual(result.tool_calls[0].function.arguments, "al") # --- Lines 253-255: cur_arguments reset (impossible branch) --- def test_streaming_arguments_reset_mid_call(self): - """Cover lines 253-255: prev has arguments but cur doesn't (impossible case)""" + """Cover lines 253-255: prev has arguments but cur doesn't (impossible case). + + This is an edge case that shouldn't happen in normal flow, but tests + defensive handling when partial parser returns no arguments after + previously having them. + """ parser = self._new_parser() parser.current_tool_id = 0 parser.current_tool_name_sent = True parser.streamed_args_for_tool = [""] + # Simulate state where prev already had arguments parser.prev_tool_call_arr = [{"name": "fn", "arguments": {"k": "v"}}] - # Feed content where cur has no arguments but prev does + # Mock parser to return no arguments (simulating the impossible reset) with patch( "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.partial_json_parser.loads", return_value={"name": "fn"}, @@ -591,9 +778,9 @@ def test_streaming_arguments_reset_mid_call(self): '{"name": "fn", "arguments": {"k": "v"', '{"name": "fn", "arguments": {"k": "v"}', '"}', - [1, 10], - [1, 10, 20], - [20], + [1, 2], + [1, 2, 3], + [3], self.dummy_request, ) self.assertIsNone(result) @@ -603,110 +790,48 @@ def test_streaming_arguments_reset_mid_call(self): def test_streaming_incremental_arguments_incomplete(self): """Cover lines 288-314: both prev and cur have arguments, JSON incomplete""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # More argument tokens (both prev and cur have arguments now) - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "val', - "al", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - self.assertIsNotNone(result) - self.assertEqual(result.tool_calls[0].function.arguments, "al") + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + "l", # incremental: both-have-args + ], + ) + # Step 1: name sent + # Step 2: first-args branch + # Step 3: both-have-args branch, streams "l" + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.arguments, "l") def test_streaming_incremental_arguments_complete_json(self): """Cover lines 289-305: complete JSON with trailing }""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, - ) - # First arguments - delta must appear in regex-extracted arguments portion - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v', - '{"k": "v', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Complete with closing braces - both prev and cur have arguments - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v', - '{"name": "fn", "arguments": {"k": "v"}}', - '"}}', - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta ends with }, should strip trailing } - # After strip: '"' which is not empty, so returns DeltaMessage - self.assertIsNotNone(result) - self.assertIsInstance(result, DeltaMessage) + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v', # start + name + first args + "a", # establishes prev_args + '"}}', # completes JSON + ], + ) + # Step 3: both-have-args, complete JSON, strips trailing } -> streams '"}' + self.assertIsNotNone(results[2]) + self.assertIsInstance(results[2], DeltaMessage) def test_streaming_incremental_arguments_complete_empty_delta(self): """Cover lines 304-305: complete JSON where delta becomes empty after strip""" parser = self._new_parser() - parser.extract_tool_calls_streaming( - "", - '{"name": "fn"', - '{"name": "fn"', - [], - [1, 10], - [1, 10], - self.dummy_request, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": {"k": "v"', # start + name + first args + "}", # inner close (establishes prev_args) + "}", # outer close: both-have-args, complete, delta stripped to "" + ], ) - # First arguments with proper delta - parser.extract_tool_calls_streaming( - '{"name": "fn"', - '{"name": "fn", "arguments": {"k": "v"}', - '{"k": "v"}', - [1, 10], - [1, 10, 20], - [20], - self.dummy_request, - ) - # Send just the outer closing brace - # tool_call_portion becomes complete JSON, delta="}" stripped to "" -> return None - result = parser.extract_tool_calls_streaming( - '{"name": "fn", "arguments": {"k": "v"}', - '{"name": "fn", "arguments": {"k": "v"}}', - "}", - [1, 10, 20], - [1, 10, 20, 30], - [30], - self.dummy_request, - ) - # is_complete_json=True, delta="}" -> stripped to "" -> return None - self.assertIsNone(result) + # Step 3: is_complete_json=True, delta="}" -> stripped to "" -> return None + self.assertIsNone(results[2]) # --- Lines 316-319: prev_tool_call_arr update branches --- @@ -759,95 +884,71 @@ def test_streaming_general_exception(self): def test_streaming_full_flow(self): """Integration test: simulate a full streaming tool call flow""" parser = self._new_parser() - req = self.dummy_request - - # Step 1: text before tool call - r = parser.extract_tool_calls_streaming("", "thinking", "thinking", [], [], [], req) - self.assertEqual(r.content, "thinking") - - # Step 2: tool_call start token - r = parser.extract_tool_calls_streaming("thinking", "thinking", "", [], [1], [1], req) - self.assertIsNone(r) + results = self._simulate_streaming( + parser, + [ + "thinking", # Step 1: text before tool call + "", # Step 2: tool_call start token + '{"name": "search", "arguments": {"query": "', # Step 3: name + args key + "test", # Step 4: args value + " data", # Step 5: more args + ], + ) + # Step 1: plain text + self.assertEqual(results[0].content, "thinking") + # Step 2: start token -> None + self.assertIsNone(results[1]) + # Step 3: name sent + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "search") + # Step 4: first arguments + self.assertIsNotNone(results[3]) + self.assertEqual(results[3].tool_calls[0].function.arguments, '{"query": "test') + # Step 5: more arguments + self.assertIsNotNone(results[4]) + self.assertEqual(results[4].tool_calls[0].function.arguments, " data") - # Step 3: function name appears - r = parser.extract_tool_calls_streaming( - "thinking", - 'thinking{"name": "search"', - '{"name": "search"', - [1], - [1, 10], - [10], - req, - ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "search") - - # Step 4: arguments start - delta must appear in regex-extracted arguments portion - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search"', - 'thinking{"name": "search", "arguments": {"query": "test', - '{"query": "test', - [1, 10], - [1, 10, 20], - [20], - req, - ) - self.assertIsNotNone(r) + def test_streaming_empty_arguments_full_flow(self): + """Integration: streaming tool call with arguments={} must not lose arguments. - # Step 5: more arguments - r = parser.extract_tool_calls_streaming( - 'thinking{"name": "search", "arguments": {"query": "test', - 'thinking{"name": "search", "arguments": {"query": "test data', - " data", - [1, 10, 20], - [1, 10, 20, 30], - [30], - req, - ) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.arguments, " data") + Simulates a complete streaming flow where the tool call has empty + arguments. Verifies the name is sent and arguments are streamed. + """ + parser = self._new_parser() + results = self._simulate_streaming( + parser, + [ + '{"name": "fn", "arguments": ', # Step 1: start + name + args key + "{}", # Step 2: empty dict value + "}", # Step 3: outer close + "", # Step 4: end token + ], + ) + # Step 1: name sent + self.assertIsNotNone(results[0]) + self.assertEqual(results[0].tool_calls[0].function.name, "fn") + # Step 2: first-args with cur_args={}, streams "{}" + self.assertIsNotNone(results[1]) + self.assertEqual(results[1].tool_calls[0].function.arguments, "{}") + # Step 4: close branch, delta_text="" after stripping + # diff={} is not None, but "}" not in "" -> return None + self.assertIsNone(results[2]) + self.assertIsNone(results[3]) def test_streaming_multiple_tool_calls(self): """Integration test: two tool calls in one response""" parser = self._new_parser() - req = self.dummy_request - - # First tool call - parser.extract_tool_calls_streaming( - "", - '{"name": "fn1"', - '{"name": "fn1"', - [], - [1, 10], - [1, 10], - req, - ) - self.assertEqual(parser.current_tool_id, 0) - - # Close first tool - parser.extract_tool_calls_streaming( - '{"name": "fn1"', - '{"name": "fn1"}', - "}", - [1, 10], - [1, 10, 2], - [2], - req, - ) - - # Second tool call - r = parser.extract_tool_calls_streaming( - '{"name": "fn1"}', - '{"name": "fn1"}{"name": "fn2"', - '{"name": "fn2"', - [1, 10, 2], - [1, 10, 2, 1, 20], - [1, 20], - req, + results = self._simulate_streaming( + parser, + [ + '{"name": "fn1"', # First tool: start + name + "}", # Close first tool + '{"name": "fn2"', # Second tool: start + name + ], ) self.assertEqual(parser.current_tool_id, 1) - self.assertIsNotNone(r) - self.assertEqual(r.tool_calls[0].function.name, "fn2") + self.assertIsNotNone(results[2]) + self.assertEqual(results[2].tool_calls[0].function.name, "fn2") if __name__ == "__main__": diff --git a/tests/entrypoints/test_engine_client.py b/tests/entrypoints/test_engine_client.py index 0ed8fbdc033..71ad4b29db9 100644 --- a/tests/entrypoints/test_engine_client.py +++ b/tests/entrypoints/test_engine_client.py @@ -102,6 +102,7 @@ def create_mock_fd_config( mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None + mock_config.enable_mm_runtime = enable_mm return mock_config @@ -181,6 +182,7 @@ async def asyncSetUp(self): mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.node_rank = 0 + mock_config.enable_mm_runtime = mock_model_config.enable_mm # Create mocks for all the external dependencies mock_input_processor = Mock() @@ -363,6 +365,7 @@ def setUp(self): mock_config.structured_outputs_config = MagicMock() # Add this mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None # Add this attribute + mock_config.enable_mm_runtime = mock_model_config.enable_mm # Mock IPCSignal to avoid file system dependencies with patch("fastdeploy.entrypoints.engine_client.IPCSignal") as mock_ipcsignal: @@ -655,6 +658,7 @@ async def test_init_basic_parameters(self): mock_config.structured_outputs_config = Mock() mock_config.structured_outputs_config.reasoning_parser = None mock_config.tool_parser = None + mock_config.enable_mm_runtime = mock_config.model_config.enable_mm client = EngineClient( pid=5678, @@ -1078,6 +1082,7 @@ async def test_init_with_multimodal_prefix_cache(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False @@ -1131,6 +1136,7 @@ async def test_init_as_worker_node(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False @@ -1408,6 +1414,7 @@ async def test_init_iluvatar_platform(self): mock_config = Mock() mock_config.model_config = mock_model_config + mock_config.enable_mm_runtime = mock_model_config.enable_mm mock_config.eplb_config = Mock() mock_config.eplb_config.enable_eplb = False diff --git a/tests/entrypoints/test_serving_completion.py b/tests/entrypoints/test_serving_completion.py index 9c2beb678df..9b48b2271a3 100644 --- a/tests/entrypoints/test_serving_completion.py +++ b/tests/entrypoints/test_serving_completion.py @@ -21,6 +21,7 @@ import paddle import fastdeploy.metrics.trace as tracing +from fastdeploy.entrypoints.openai.protocol import CompletionResponse from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion from fastdeploy.utils import ErrorCode, ParameterError from fastdeploy.worker.output import LogprobsLists, LogprobsTensors, SpeculateMetrics @@ -171,7 +172,8 @@ async def test_completion_full_generator_branches(self): ec.connection_manager.get_connection = AsyncMock(return_value=(Mock(), rq)) serving = OpenAIServingCompletion(ec, None, "pid", None, -1) res = await serving.completion_full_generator(_make_request(), 1, "req", 1, "m", [[1, 2]], [["p1", "p2"]], [2]) - self.assertIsNone(res) + self.assertIsNotNone(res) + self.assertIsInstance(res, CompletionResponse) ec.connection_manager.cleanup_request.assert_called_once_with("req") def test_logprobs_helpers(self): diff --git a/tests/eplb/test_eplb_utils.py b/tests/eplb/test_eplb_utils.py index 4b367c2a36d..08d4d89dc0f 100644 --- a/tests/eplb/test_eplb_utils.py +++ b/tests/eplb/test_eplb_utils.py @@ -168,7 +168,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.num_hidden_layers = 3 @@ -200,7 +199,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, eplb_config=eplb_config, ) diff --git a/tests/eplb/test_experts_manager.py b/tests/eplb/test_experts_manager.py index b736c20f263..15060ea480a 100644 --- a/tests/eplb/test_experts_manager.py +++ b/tests/eplb/test_experts_manager.py @@ -48,7 +48,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.num_hidden_layers = 3 @@ -80,7 +79,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, eplb_config=eplb_config, ) diff --git a/tests/graph_optimization/test_cuda_graph_recapture.py b/tests/graph_optimization/test_cuda_graph_recapture.py index 1a28c0731b3..902bcf182fd 100644 --- a/tests/graph_optimization/test_cuda_graph_recapture.py +++ b/tests/graph_optimization/test_cuda_graph_recapture.py @@ -91,10 +91,10 @@ def forward_correct(self, ids_remove_padding, forward_meta: ForwardMeta): return sublayer2_output - def clear_grpah_opt_backend(self): + def clear_graph_opt_backend(self): """ """ - self.sublayer1.clear_grpah_opt_backend(fd_config=self.fd_config) - self.sublayer2.clear_grpah_opt_backend(fd_config=self.fd_config) + self.sublayer1.clear_graph_opt_backend(fd_config=self.fd_config) + self.sublayer2.clear_graph_opt_backend(fd_config=self.fd_config) class TestCUDAGrpahRecapture(unittest.TestCase): @@ -152,7 +152,7 @@ def capture_and_replay(self, input_tensor1, forward_meta1): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) def recapture_and_replay(self, input_tensor1, forward_meta1): @@ -168,7 +168,7 @@ def recapture_and_replay(self, input_tensor1, forward_meta1): # Destroy print_gpu_memory_use("before destroy", 0) - self.test_model1.clear_grpah_opt_backend() + self.test_model1.clear_graph_opt_backend() print_gpu_memory_use("after destroy", 0) diff --git a/tests/input/test_ernie_vl_processor.py b/tests/input/test_ernie_vl_processor.py index 6e4fac00182..c440187667f 100644 --- a/tests/input/test_ernie_vl_processor.py +++ b/tests/input/test_ernie_vl_processor.py @@ -361,14 +361,14 @@ def test_process_response_dict(self): # Test with stream=True processor.process_response_dict_streaming = MagicMock(return_value={"text": "response"}) - response_dict = {"ids": [1, 2, 3]} + response_dict = {"ids": [1, 2, 3], "outputs": [[1, 2, 3]]} result = processor.process_response_dict(response_dict, stream=True) processor.process_response_dict_streaming.assert_called_once() self.assertEqual(result, {"text": "response"}) # Test with stream=False processor.process_response_dict_normal = MagicMock(return_value={"text": "response"}) - response_dict = {"ids": [1, 2, 3]} + response_dict = {"ids": [1, 2, 3], "outputs": [[1, 2, 3]]} result = processor.process_response_dict(response_dict, stream=False) processor.process_response_dict_normal.assert_called_once() self.assertEqual(result, {"text": "response"}) diff --git a/tests/input/test_text_processor.py b/tests/input/test_text_processor.py index 818f4b77d0f..56137633350 100644 --- a/tests/input/test_text_processor.py +++ b/tests/input/test_text_processor.py @@ -418,6 +418,31 @@ def test_process_request_dict_rejects_bad_kwargs(self): with self.assertRaisesRegex(ValueError, "chat_template_kwargs must be a dict"): self.processor.process_request_dict(request) + def test_process_request_dict_completion_token_ids_extend(self): + request = {"prompt": "hi", "completion_token_ids": [10, 11, 12], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # prompt "hi" is tokenized to [2] by DummyTokenizer, then extended with completion_token_ids + self.assertEqual(processed["prompt_token_ids"], [2, 10, 11, 12]) + + def test_process_request_dict_no_completion_token_ids(self): + request = {"prompt": "hi", "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # without completion_token_ids, prompt_token_ids should remain as tokenized result + self.assertEqual(processed["prompt_token_ids"], [2]) + + def test_process_request_dict_empty_completion_token_ids(self): + request = {"prompt": "hi", "completion_token_ids": [], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=20) + # empty list is falsy, should not extend prompt_token_ids + self.assertEqual(processed["prompt_token_ids"], [2]) + + def test_process_request_dict_completion_token_ids_truncated(self): + # prompt "hi" -> [2], extend [10,11,12] -> [2,10,11,12] (len=4) + # max_model_len=3, 4 > 3 triggers truncation: [:3-1] = [:2] -> [2, 10] + request = {"prompt": "hi", "completion_token_ids": [10, 11, 12], "temperature": 0, "top_p": 0} + processed = self.processor.process_request_dict(request, max_model_len=3) + self.assertEqual(processed["prompt_token_ids"], [2, 10]) + def test_ids2tokens_and_clear_request_status(self): delta, _, _ = self.processor.ids2tokens([3], "task-1") self.assertEqual(delta, "3") diff --git a/tests/inter_communicator/test_e2w_queue.py b/tests/inter_communicator/test_e2w_queue.py index 97a17346c91..333249cc66d 100644 --- a/tests/inter_communicator/test_e2w_queue.py +++ b/tests/inter_communicator/test_e2w_queue.py @@ -16,14 +16,13 @@ import threading import time -import types import unittest import numpy as np import paddle -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **_: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **_: None from fastdeploy import envs from fastdeploy.engine.request import Request @@ -302,15 +301,15 @@ def test_wait_loops_and_tensor_conversion(self): client.get_finished_req() thread.join() - client.can_put_next_add_task_finished_flag.set(0) - thread = self._set_value_after_delay(client.can_put_next_add_task_finished_flag, 1) - client.put_finished_add_cache_task_req(["req-wait"]) + client.can_put_next_send_cache_finished_flag.set(0) + thread = self._set_value_after_delay(client.can_put_next_send_cache_finished_flag, 1) + client.put_finished_req([["req-wait", {"status": "ok"}]]) thread.join() - client.finished_add_cache_task_list.append(["req-wait"]) - client.client_get_finished_add_cache_task_flag[:] = [0] - thread = self._set_list_after_delay(client.client_get_finished_add_cache_task_flag, [1]) - client.get_finished_add_cache_task_req() + client.finished_send_cache_list.append(["req-wait", {"error": "bad"}]) + client.client_get_finish_send_cache_flag[:] = [0] + thread = self._set_list_after_delay(client.client_get_finish_send_cache_flag, [1]) + client.get_finished_req() thread.join() finally: paddle.set_device(previous_device) @@ -362,18 +361,6 @@ def test_finished_req_flow(self): finally: self._cleanup_queue_pair(server) - def test_finished_add_cache_task_req(self): - server, client = self._build_queue_pair() - try: - req_ids = ["req-2"] - self.assertTrue(client.put_finished_add_cache_task_req(req_ids)) - client.finished_add_cache_task_list.append(req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), req_ids) - self.assertEqual(client.get_finished_add_cache_task_req(), []) - self.assertEqual(client.can_put_next_add_task_finished_flag.get(), 1) - finally: - self._cleanup_queue_pair(server) - def test_disaggregated_queue(self): server, client = self._build_queue_pair() try: diff --git a/tests/inter_communicator/test_zmq_server.py b/tests/inter_communicator/test_zmq_server.py index 57c9a0c479a..17925f219f1 100644 --- a/tests/inter_communicator/test_zmq_server.py +++ b/tests/inter_communicator/test_zmq_server.py @@ -6,7 +6,6 @@ import tempfile import threading import time -import types import unittest from collections import defaultdict from unittest import mock @@ -16,8 +15,8 @@ import zmq from zmq.utils import jsonapi -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda **kwargs: None from fastdeploy import envs from fastdeploy.inter_communicator.zmq_server import ( diff --git a/tests/layers/test_deepgemm_fused_moe.py b/tests/layers/test_deepgemm_fused_moe.py index 5381ee866a3..4ec3e017e20 100644 --- a/tests/layers/test_deepgemm_fused_moe.py +++ b/tests/layers/test_deepgemm_fused_moe.py @@ -106,7 +106,10 @@ def __init__(self): # ep_size * this = max tokens buffer for masked GEMM; must be ≥ aligned M num_max_dispatch_tokens_per_rank=128, ) - self.scheduler_config = types.SimpleNamespace(max_num_batched_tokens=NUM_TOKENS) + self.scheduler_config = types.SimpleNamespace( + max_num_batched_tokens=NUM_TOKENS, + enable_moe_scores_elementwise_fuse=False, + ) self.parallel_config = types.SimpleNamespace(tensor_parallel_size=1) @@ -205,6 +208,22 @@ def hook(topk_ids): assert "topk_ids" in captured assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + @requires_deepgemm + def test_apply_tp_noaux_tc_with_use_fused_true(self): + """noaux_tc path with enable_moe_scores_elementwise_fuse=True: triggers use_fused=True (no gate_out.cast).""" + layer = DummyLayer() + layer.topk_method = "noaux_tc" + gate = DummyGate(layer.num_local_experts) + method = _make_method() + + x = paddle.randn([NUM_TOKENS, HIDDEN_SIZE], dtype="bfloat16") + + # Enable flag to exercise the fused path (use_fused=True) + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + + out = method.apply(layer, x, gate) + assert list(out.shape) == [NUM_TOKENS, HIDDEN_SIZE] + @requires_deepgemm def test_apply_tp_aux_path(self): """Non-noaux_tc: moe_topk_select → fp8_quant_blockwise → moe_permute → deepgemm → moe_unpermute.""" diff --git a/tests/layers/test_fused_cast_sigmoid_bias.py b/tests/layers/test_fused_cast_sigmoid_bias.py new file mode 100644 index 00000000000..21bfb0901fd --- /dev/null +++ b/tests/layers/test_fused_cast_sigmoid_bias.py @@ -0,0 +1,497 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib +import os +import sys +from unittest import mock + +import paddle +import paddle.nn.functional as F +import pytest + +from fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias import ( + fused_cast_sigmoid_bias, + is_available, +) + +DTYPE_MAP = { + "float16": paddle.float16, + "bfloat16": paddle.bfloat16, + "float32": paddle.float32, +} + + +def _ensure_gpu_test_environment(): + """Ensure GPU runtime and required custom ops are available for this test module.""" + if not paddle.is_compiled_with_cuda(): + pytest.skip( + "fused_cast_sigmoid_bias requires CUDA-enabled Paddle.", + allow_module_level=True, + ) + paddle.set_device("gpu") + + +_ensure_gpu_test_environment() + + +def reference_cast_sigmoid_bias(gate_out, bias, cast_type="float32"): + """Reference implementation: compute in fp32, cast output to cast_type.""" + gate_fp32 = gate_out.cast("float32") + scores_fp32 = F.sigmoid(gate_fp32) + scores_with_bias_fp32 = scores_fp32 + bias + scores = scores_fp32.cast(cast_type) + scores_with_bias = scores_with_bias_fp32.cast(cast_type) + return scores, scores_with_bias + + +def test_functionality(): + """Test basic functionality: correct shapes and dtypes (default cast_type=float32).""" + print("=" * 60) + print("Test 1: Functionality (default cast_type=float32)") + print("=" * 60) + + for dtype_name in ["float16", "bfloat16", "float32"]: + for num_tokens in [1, 7, 128, 1024]: + for num_experts in [8, 64, 128, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + assert scores.shape == [ + num_tokens, + num_experts, + ], f"scores shape mismatch: {scores.shape} vs {[num_tokens, num_experts]}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert scores.dtype == paddle.float32, f"scores dtype mismatch: {scores.dtype}" + assert ( + scores_with_bias.dtype == paddle.float32 + ), f"scores_with_bias dtype mismatch: {scores_with_bias.dtype}" + + # Sigmoid output should be in [0, 1] + assert bool(paddle.all(scores >= 0.0).item()) and bool( + paddle.all(scores <= 1.0).item() + ), "scores out of [0,1] range" + print(f" [PASS] dtype={dtype_name}") + + print(" All functionality tests passed.\n") + + +def test_functionality_cast_types(): + """Test functionality with different cast_type values.""" + print("=" * 60) + print("Test 1b: Functionality with different cast_type") + print("=" * 60) + + for input_dtype in ["float16", "bfloat16", "float32"]: + for cast_type in ["float16", "bfloat16", "float32"]: + expected_paddle_dtype = DTYPE_MAP[cast_type] + for num_tokens in [1, 64, 256]: + for num_experts in [8, 64, 256]: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + scores, scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + assert scores.shape == [num_tokens, num_experts], f"scores shape mismatch: {scores.shape}" + assert scores_with_bias.shape == [ + num_tokens, + num_experts, + ], f"scores_with_bias shape mismatch: {scores_with_bias.shape}" + assert ( + scores.dtype == expected_paddle_dtype + ), f"scores dtype mismatch: got {scores.dtype}, expected {expected_paddle_dtype}" + assert ( + scores_with_bias.dtype == expected_paddle_dtype + ), f"scores_with_bias dtype mismatch: got {scores_with_bias.dtype}, expected {expected_paddle_dtype}" + + print(f" [PASS] input_dtype={input_dtype}, cast_type={cast_type}") + + print(" All cast_type functionality tests passed.\n") + + +def test_accuracy(): + """Test numerical accuracy against reference implementation (default cast_type=float32).""" + print("=" * 60) + print("Test 2: Accuracy (default cast_type=float32)") + print("=" * 60) + + test_cases = [ + ("float16", 1, 8), + ("float16", 128, 256), + ("float16", 1024, 256), + ("bfloat16", 1, 8), + ("bfloat16", 128, 256), + ("bfloat16", 1024, 256), + ("float32", 1, 8), + ("float32", 128, 256), + ("float32", 1024, 256), + ] + + for dtype_name, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias) + + # Compare + scores_diff = paddle.abs(fused_scores - ref_scores).max().item() + scores_bias_diff = paddle.abs(fused_scores_with_bias - ref_scores_with_bias).max().item() + + atol = 1e-6 if dtype_name == "float32" else 1e-3 + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] dtype={dtype_name}, tokens={num_tokens}, experts={num_experts} | " + f"scores_max_diff={scores_diff:.2e}, scores_with_bias_max_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for dtype={dtype_name}, tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, scores_bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All accuracy tests passed.\n") + + +def test_accuracy_cast_types(): + """Test numerical accuracy with different cast_type values.""" + print("=" * 60) + print("Test 2b: Accuracy with different cast_type") + print("=" * 60) + + # (input_dtype, cast_type, num_tokens, num_experts) + test_cases = [ + # cast to float32 (original behavior) + ("float16", "float32", 128, 256), + ("bfloat16", "float32", 128, 256), + ("float32", "float32", 128, 256), + # cast to float16 + ("float16", "float16", 128, 256), + ("bfloat16", "float16", 128, 256), + ("float32", "float16", 128, 256), + # cast to bfloat16 + ("float16", "bfloat16", 128, 256), + ("bfloat16", "bfloat16", 128, 256), + ("float32", "bfloat16", 128, 256), + # different shapes + ("bfloat16", "float16", 1, 8), + ("bfloat16", "float16", 1024, 256), + ("float16", "bfloat16", 1, 8), + ("float16", "bfloat16", 1024, 256), + ] + + for input_dtype, cast_type, num_tokens, num_experts in test_cases: + gate_out = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = paddle.randn([num_experts], dtype="float32") + + # Fused kernel + fused_scores, fused_scores_with_bias = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Reference + ref_scores, ref_scores_with_bias = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + + # Compare in float32 for stable diff computation + scores_diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + scores_bias_diff = ( + paddle.abs(fused_scores_with_bias.cast("float32") - ref_scores_with_bias.cast("float32")).max().item() + ) + + # Tolerance depends on cast_type precision + if cast_type == "float32": + atol = 1e-6 + elif cast_type == "bfloat16": + atol = 1e-2 # bfloat16 has fewer mantissa bits + else: # float16 + atol = 1e-3 + + passed = scores_diff < atol and scores_bias_diff < atol + + status = "PASS" if passed else "FAIL" + print( + f" [{status}] input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts} | " + f"scores_diff={scores_diff:.2e}, bias_diff={scores_bias_diff:.2e}" + ) + + if not passed: + raise AssertionError( + f"Accuracy test failed for input={input_dtype}, cast_type={cast_type}, " + f"tokens={num_tokens}, experts={num_experts}. " + f"scores_diff={scores_diff}, bias_diff={scores_bias_diff}, atol={atol}" + ) + + print(" All cast_type accuracy tests passed.\n") + + +def test_accuracy_extreme_values(): + """Test accuracy with extreme input values.""" + print("=" * 60) + print("Test 3: Accuracy with extreme values") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for dtype_name in ["float16", "bfloat16"]: + # Large positive values -> sigmoid ~ 1.0 + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=dtype_name) + bias = paddle.zeros([num_experts], dtype="float32") + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large positive: max_diff={diff:.2e}") + + # Large negative values -> sigmoid ~ 0.0 + gate_out = paddle.full([num_tokens, num_experts], -10.0, dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + print(f" [{'PASS' if diff < 1e-5 else 'FAIL'}] dtype={dtype_name}, large negative: max_diff={diff:.2e}") + + # Zero values -> sigmoid = 0.5 + gate_out = paddle.zeros([num_tokens, num_experts], dtype=dtype_name) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias) + diff = paddle.abs(fused_scores - ref_scores).max().item() + assert diff < 1e-6, f"Zero input test failed: diff={diff}" + print(f" [PASS] dtype={dtype_name}, zeros: max_diff={diff:.2e}") + + print(" All extreme value tests passed.\n") + + +def test_accuracy_extreme_values_cast_types(): + """Test accuracy with extreme values across different cast_type values.""" + print("=" * 60) + print("Test 3b: Accuracy with extreme values + different cast_type") + print("=" * 60) + + num_tokens, num_experts = 64, 256 + + for input_dtype in ["float16", "bfloat16"]: + for cast_type in ["float16", "bfloat16", "float32"]: + bias = paddle.zeros([num_experts], dtype="float32") + + # Large positive + gate_out = paddle.full([num_tokens, num_experts], 10.0, dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + status = "PASS" if diff < atol else "FAIL" + print(f" [{status}] input={input_dtype}, cast={cast_type}, " f"large positive: diff={diff:.2e}") + + # Zero values + gate_out = paddle.zeros([num_tokens, num_experts], dtype=input_dtype) + fused_scores, _ = fused_cast_sigmoid_bias(gate_out, bias, cast_type) + ref_scores, _ = reference_cast_sigmoid_bias(gate_out, bias, cast_type) + diff = paddle.abs(fused_scores.cast("float32") - ref_scores.cast("float32")).max().item() + atol = 1e-2 if cast_type == "bfloat16" else 1e-5 + assert diff < atol, f"Zero input test failed: input={input_dtype}, cast={cast_type}, diff={diff}" + print(f" [PASS] input={input_dtype}, cast={cast_type}, " f"zeros: diff={diff:.2e}") + + print(" All extreme value cast_type tests passed.\n") + + +@pytest.mark.skipif( + os.getenv("RUN_PERFORMANCE_TESTS") != "1", + reason="Performance benchmark is disabled by default. Set RUN_PERFORMANCE_TESTS=1 to enable.", +) +def test_performance(): + """Benchmark fused kernel vs reference implementation using CUDA events.""" + print("=" * 60) + print("Test 4: Performance (CUDA event timing)") + print("=" * 60) + + configs = [ + ("bfloat16", 1, 256), # single token decode + ("bfloat16", 8, 256), # small batch decode + ("bfloat16", 64, 256), # medium batch + ("bfloat16", 256, 256), # typical DeepSeek-V3 config + ("bfloat16", 1024, 256), # large prefill + ("bfloat16", 4096, 256), # very large prefill + ] + + warmup_iters = 100 + bench_iters = 500 + + for dtype_name, num_tokens, num_experts in configs: + gate_out = paddle.randn([num_tokens, num_experts], dtype=dtype_name) + bias = paddle.randn([num_experts], dtype="float32") + + # Warmup fused + for _ in range(warmup_iters): + fused_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark fused with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + fused_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + fused_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + # Warmup reference + for _ in range(warmup_iters): + reference_cast_sigmoid_bias(gate_out, bias) + paddle.device.synchronize() + + # Benchmark reference with CUDA events + start_event = paddle.device.cuda.Event(enable_timing=True) + end_event = paddle.device.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(bench_iters): + reference_cast_sigmoid_bias(gate_out, bias) + end_event.record() + paddle.device.synchronize() + ref_time = start_event.elapsed_time(end_event) / bench_iters * 1e3 # us + + speedup = ref_time / fused_time if fused_time > 0 else float("inf") + print( + f" tokens={num_tokens:5d}, experts={num_experts:3d} | " + f"ref={ref_time:8.1f}us, fused={fused_time:8.1f}us, speedup={speedup:.2f}x" + ) + + print() + print(" Note: The CUDA custom op fuses cast+sigmoid+bias into a single kernel,") + print(" eliminating 2 intermediate tensors and reducing kernel launches from 3 to 1.") + print(" Expected speedup: ~3x over the reference 3-op implementation.") + print(" Performance benchmark complete.\n") + + +def test_is_available(): + """Test is_available() function returns True when GPU ops are available.""" + print("=" * 60) + print("Test: is_available()") + print("=" * 60) + + # In normal GPU test environment, is_available should return True + result = is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is True, f"is_available() should return True when GPU ops are compiled, got {result}" + print(f" [PASS] is_available() returned {result}") + print(" is_available() test passed.\n") + + +def test_import_error(): + """Test that ImportError is raised when GPU ops are not available.""" + print("=" * 60) + print("Test 5: Import error handling") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # The module should load successfully, but calling the function + # should raise ImportError because the cuda op is unavailable. + dummy_gate = paddle.randn([1, 8], dtype="float32") + dummy_bias = paddle.randn([8], dtype="float32") + try: + reloaded.fused_cast_sigmoid_bias(dummy_gate, dummy_bias) + raise AssertionError("Expected ImportError was not raised") + except ImportError as e: + assert "fused_cast_sigmoid_bias is not available" in str(e), f"Unexpected error message: {e}" + print(f" [PASS] ImportError raised with correct message: {e}") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" Import error handling test passed.\n") + + +def test_is_available_when_ops_unavailable(): + """Test is_available() returns False when GPU ops are not available.""" + print("=" * 60) + print("Test: is_available() when ops unavailable") + print("=" * 60) + + module_name = "fastdeploy.model_executor.layers.moe.fused_cast_sigmoid_bias" + gpu_ops_module = "fastdeploy.model_executor.ops.gpu" + + # Save original module references + original_module = sys.modules.pop(module_name, None) + original_gpu_ops = sys.modules.get(gpu_ops_module) + + try: + # Mock the GPU ops module to raise ImportError on import + with mock.patch.dict(sys.modules, {gpu_ops_module: None}): + # Re-import the module so it picks up the mocked (missing) GPU ops + reloaded = importlib.import_module(module_name) + importlib.reload(reloaded) + + # is_available should return False when ops are not available + result = reloaded.is_available() + assert isinstance(result, bool), f"is_available() should return bool, got {type(result)}" + assert result is False, f"is_available() should return False when GPU ops are unavailable, got {result}" + print(f" [PASS] is_available() returned {result} when ops unavailable") + finally: + # Restore original modules + sys.modules.pop(module_name, None) + if original_module is not None: + sys.modules[module_name] = original_module + if original_gpu_ops is not None: + sys.modules[gpu_ops_module] = original_gpu_ops + + print(" is_available() when ops unavailable test passed.\n") + + +if __name__ == "__main__": + print("Running fused_cast_sigmoid_bias tests...\n") + + test_is_available() + test_functionality() + test_functionality_cast_types() + test_accuracy() + test_accuracy_cast_types() + test_accuracy_extreme_values() + test_accuracy_extreme_values_cast_types() + test_import_error() + test_is_available_when_ops_unavailable() + if os.getenv("RUN_PERFORMANCE_TESTS") == "1": + test_performance() + else: + print("Skipping performance benchmark. Set RUN_PERFORMANCE_TESTS=1 to enable.\n") + + print("=" * 60) + print("All tests passed!") + print("=" * 60) diff --git a/tests/layers/test_fused_moe_cutlass_backend.py b/tests/layers/test_fused_moe_cutlass_backend.py index 2e8ea281daa..f6f92fb44da 100644 --- a/tests/layers/test_fused_moe_cutlass_backend.py +++ b/tests/layers/test_fused_moe_cutlass_backend.py @@ -23,8 +23,8 @@ import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None iluvatar_stub = types.ModuleType("fastdeploy.model_executor.ops.iluvatar") iluvatar_stub.moe_expert_ffn = lambda *args, **kwargs: None @@ -35,10 +35,15 @@ iluvatar_stub.prefill_fused_paged_attention = lambda *args, **kwargs: None sys.modules["fastdeploy.model_executor.ops.iluvatar"] = iluvatar_stub +import fastdeploy # noqa: E402 from fastdeploy.model_executor.layers import utils as layer_utils from fastdeploy.model_executor.layers.moe import fused_moe_cutlass_backend as backend +def align(x, y): + return (x + y - 1) // y * y + + class DummyQuantConfig: def __init__(self, algo="weight_only_int8", is_quantized=False, is_checkpoint_bf16=False): self.algo = algo @@ -53,6 +58,7 @@ class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.model_config = types.SimpleNamespace(model="dummy", prefix_layer_name="prefix") self.load_config = types.SimpleNamespace(load_choices=load_choices) + self.scheduler_config = types.SimpleNamespace(enable_moe_scores_elementwise_fuse=False) class DummyLayer(paddle.nn.Layer): @@ -388,7 +394,17 @@ def combine(self, ffn_out, topk_idx, topk_weights, handle, quant_group_size=-1): np.testing.assert_allclose(out.numpy(), np.full((1, 2), 5.0)) def test_apply_tp_with_dispatch_and_reduce(self, monkeypatch): - def fake_get_moe_scores(gate_out, n_group, topk_group, top_k, routed_scaling_factor, bias, renormalize): + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) def fake_dispatch(*args, **kwargs): @@ -707,3 +723,295 @@ def test_weight_only_prequanted_and_int4_create(self): int4_method.create_weights( int4_layer, num_experts=2, hidden_size=4, moe_intermediate_size=2, model_format="paddle" ) + + +# --------------------------------------------------------------------------- +# Real-op tests for FD_USE_PHI_MOE_PERMUTE=True (w16a16, moe_permute path) +# --------------------------------------------------------------------------- + +from fastdeploy.platforms import current_platform # noqa: E402 + +_CUDA_AVAILABLE = current_platform.is_cuda() +requires_cuda = pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA required") + + +class RealMoELayer(paddle.nn.Layer): + """Minimal bf16 MoE layer with real weights for moe_permute path testing.""" + + def __init__(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2): + super().__init__() + self.fd_config = DummyFDConfig() + self.num_experts = num_experts + self.num_local_experts = num_experts + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.top_k = top_k + self.topk_method = "noaux_tc" + self.n_group = 1 + self.topk_group = 1 + self.routed_scaling_factor = 1.0 + self.with_bias = False + self.ep_size = 1 + self.ep_rank = 0 + self.layer_idx = 0 + self.weight_dtype = "bfloat16" + self.is_quantized = False + self.activation = "swiglu" + self.moe_quant_config = types.SimpleNamespace(moe_dynamic_quant=False, hadamard_block_size=128) + self.gate_correction_bias = self.create_parameter( + shape=[1, num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + paddle.seed(0) + self.up_gate_proj_weight = self.create_parameter( + shape=[num_experts, hidden_size, 2 * moe_intermediate_size], + dtype="bfloat16", + ) + self.down_proj_weight = self.create_parameter( + shape=[num_experts, moe_intermediate_size, hidden_size], + dtype="bfloat16", + ) + self.up_gate_proj_weight.set_value( + paddle.randn([num_experts, hidden_size, 2 * moe_intermediate_size]).cast("bfloat16") * 0.01 + ) + self.down_proj_weight.set_value( + paddle.randn([num_experts, moe_intermediate_size, hidden_size]).cast("bfloat16") * 0.01 + ) + + +class SimpleLinearGate(paddle.nn.Layer): + def __init__(self, hidden_size, num_experts): + super().__init__() + self.weight = self.create_parameter(shape=[hidden_size, num_experts], dtype="float32") + + def forward(self, x): + return paddle.matmul(x.cast("float32"), self.weight) + + +class TestMoePermuteTrueRealOps: + """Real-op tests for FD_USE_PHI_MOE_PERMUTE=True on the w16a16 path.""" + + def _build(self, num_experts=4, hidden_size=64, moe_intermediate_size=32, top_k=2): + layer = RealMoELayer( + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=moe_intermediate_size, + top_k=top_k, + ) + gate = SimpleLinearGate(hidden_size, num_experts) + method = backend.CutlassMoEMethod(None) + method.moe_quant_type = "w16a16" + return layer, gate, method + + @requires_cuda + def test_apply_tp_moe_permute_real_ops(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16: real moe_permute/moe_unpermute/ + count_tokens_per_expert_func/moe_expert_ffn all called end-to-end.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + num_tokens, hidden_size = 8, 64 + layer, gate, method = self._build(hidden_size=hidden_size) + + paddle.seed(42) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Spy: confirm moe_permute is called, moe_expert_dispatch is NOT + permute_called = {"v": False} + dispatch_called = {"v": False} + original_permute = paddle.nn.functional.moe_permute + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + monkeypatch.setattr( + backend, + "moe_expert_dispatch", + lambda *a, **kw: (_ for _ in ()).throw(AssertionError("moe_expert_dispatch must not be called")), + ) + + out = method.apply_tp(layer, x, gate) + + assert permute_called["v"], "moe_permute was not called" + assert not dispatch_called["v"], "moe_expert_dispatch must not be called" + assert list(out.shape) == [num_tokens, hidden_size], f"wrong output shape: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" + + def test_apply_tp_noaux_tc_with_use_fused_true(self, monkeypatch): + def fake_get_moe_scores( + gate_out, + n_group, + topk_group, + top_k, + routed_scaling_factor, + bias, + renormalize, + topk_reduce_func=None, + use_fused_cast=False, + ): + return gate_out, paddle.to_tensor([[0.6, 0.4]]), paddle.to_tensor([[0, 1]]) + + def fake_dispatch(*args, **kwargs): + return ( + paddle.ones([1, 2]), + paddle.to_tensor([1, 0]), + paddle.to_tensor([0]), + paddle.to_tensor([[0.6, 0.4]]), + paddle.to_tensor([[0, 1]]), + paddle.to_tensor([0]), + None, + None, + ) + + def fake_reduce(*args, **kwargs): + return paddle.ones([1, 2]) * 5 + + def fake_compute_ffn(*args, **kwargs): + return paddle.ones([1, 2]) * 2 + + monkeypatch.setattr(backend, "get_moe_scores", fake_get_moe_scores, raising=False) + monkeypatch.setattr(backend, "moe_expert_dispatch", fake_dispatch, raising=False) + monkeypatch.setattr(backend, "moe_expert_reduce", fake_reduce, raising=False) + + # Mock compute_ffn on the class to avoid real GPU op data type issues + monkeypatch.setattr(backend.CutlassMoEMethod, "compute_ffn", fake_compute_ffn) + + # Enable enable_moe_scores_elementwise_fuse and force is_cuda=True to trigger use_fused = True + monkeypatch.setattr(backend, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) + layer = DummyLayer(with_bias=False) + layer.topk_method = "noaux_tc" + layer.fd_config.scheduler_config.enable_moe_scores_elementwise_fuse = True + # Add necessary attributes for compute_ffn access + layer.up_gate_proj_weight = paddle.zeros([2, 2 * 1], dtype="float16") + layer.down_proj_weight = paddle.zeros([2, 2], dtype="float16") + layer.activation = "silu" + + method = backend.CutlassMoEMethod(None) + + x = paddle.ones([1, 2]) + gate = paddle.nn.Identity() + + method.apply(layer, x, gate) + + @requires_cuda + def test_apply_ep_prefill_moe_permute_real_ops(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16: EP prefill uses real moe_permute / + moe_unpermute / count_tokens_per_expert_func / moe_expert_ffn end-to-end. + The EP dispatch/combine are stubbed (no real NCCL needed). + Use num_tokens=128 and num_experts=4 so each expert gets exactly 64 tokens + (128 * top_k=2 / 4 experts = 64), satisfying moe_expert_ffn alignment.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + # 128 tokens, top_k=2, 4 experts → 64 tokens/expert (128-aligned after padding) + num_tokens, hidden_size = 128, 64 + layer, gate, method = self._build(num_experts=4, hidden_size=hidden_size, top_k=2) + + paddle.seed(42) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Stub only the EP communication runner (dispatch/combine). + # All on-device compute (moe_permute, moe_expert_ffn, moe_unpermute) runs for real. + class StubEPRunner: + ep_engine = types.SimpleNamespace(async_finish=False) + + def moe_select(self, _layer, gate_out): + n = gate_out.shape[0] + # Route token i to experts (i % E) and ((i+1) % E) so all experts + # get tokens and recv_num_tokens_per_expert_list is accurate. + E = _layer.num_local_experts + idx0 = paddle.arange(n, dtype="int64") % E + idx1 = (paddle.arange(n, dtype="int64") + 1) % E + topk_ids = paddle.stack([idx0, idx1], axis=1) + topk_weights = paddle.ones([n, _layer.top_k], dtype="float32") / _layer.top_k + return topk_ids, topk_weights + + def dispatch(self, x, topk_idx, topk_weights, **kwargs): + # Pass tensors through unchanged — single-rank, no real communication. + # Compute accurate recv_num_tokens_per_expert_list from topk_idx. + E = layer.num_local_experts + counts = [ + align(int((topk_idx == e).sum().item()), kwargs.get("expert_alignment", 1)) for e in range(E) + ] + return ( + x, + topk_idx, + topk_weights, + counts, + object(), + types.SimpleNamespace(current_stream_wait=lambda: None), + ) + + def combine(self, ffn_out, handle, recv_topk_weights): + return ffn_out, types.SimpleNamespace(current_stream_wait=lambda: None) + + method.ep_prefill_runner = StubEPRunner() + + # Spy: confirm moe_permute is called inside ep_prefill + permute_called = {"v": False} + original_permute = paddle.nn.functional.moe_permute + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + + out = method.apply_ep_prefill(layer, x, gate) + + assert permute_called["v"], "moe_permute was not called in ep_prefill path" + assert len(out.shape) == 2, f"wrong output ndim: {out.shape}" + assert out.shape[1] == hidden_size, f"wrong hidden_size: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" + + @requires_cuda + def test_apply_tp_moe_permute_non_noaux_tc(self, monkeypatch): + """FD_USE_PHI_MOE_PERMUTE=True + w16a16 + topk_method != 'noaux_tc': + the else-branch calls moe_topk_select instead of get_moe_scores, + then proceeds through moe_permute / moe_expert_ffn / moe_unpermute.""" + monkeypatch.setattr(backend.fastdeploy.envs, "FD_USE_PHI_MOE_PERMUTE", True) + + num_tokens, hidden_size = 8, 64 + layer, gate, method = self._build(hidden_size=hidden_size) + # Switch to non-noaux_tc to exercise the else-branch (moe_topk_select) + layer.topk_method = "greedy" + + paddle.seed(7) + x = paddle.randn([num_tokens, hidden_size], dtype="bfloat16") + + # Spy on which routing function is invoked + get_moe_scores_called = {"v": False} + moe_topk_select_called = {"v": False} + permute_called = {"v": False} + + original_get_moe_scores = backend.get_moe_scores + original_moe_topk_select = fastdeploy.model_executor.ops.gpu.moe_topk_select + original_permute = paddle.nn.functional.moe_permute + + def spy_get_moe_scores(*args, **kwargs): + get_moe_scores_called["v"] = True + return original_get_moe_scores(*args, **kwargs) + + def spy_moe_topk_select(*args, **kwargs): + moe_topk_select_called["v"] = True + return original_moe_topk_select(*args, **kwargs) + + def spy_permute(*args, **kwargs): + permute_called["v"] = True + return original_permute(*args, **kwargs) + + monkeypatch.setattr(backend, "get_moe_scores", spy_get_moe_scores) + monkeypatch.setattr(fastdeploy.model_executor.ops.gpu, "moe_topk_select", spy_moe_topk_select) + monkeypatch.setattr(paddle.nn.functional, "moe_permute", spy_permute) + + out = method.apply_tp(layer, x, gate) + + assert not get_moe_scores_called["v"], "get_moe_scores must NOT be called for non-noaux_tc" + assert moe_topk_select_called["v"], "moe_topk_select must be called for non-noaux_tc" + assert permute_called["v"], "moe_permute must be called" + assert list(out.shape) == [num_tokens, hidden_size], f"wrong shape: {out.shape}" + assert not paddle.isnan(out).any(), "output contains NaN" + assert not paddle.isinf(out).any(), "output contains Inf" diff --git a/tests/layers/test_fused_moe_triton_backend.py b/tests/layers/test_fused_moe_triton_backend.py index 1140cf72b16..b42db5cc3d3 100644 --- a/tests/layers/test_fused_moe_triton_backend.py +++ b/tests/layers/test_fused_moe_triton_backend.py @@ -20,11 +20,12 @@ import sys import types +import numpy as np import paddle import pytest -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda scope=None: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None if not hasattr(paddle.nn.functional, "swiglu"): paddle.nn.functional.swiglu = lambda x: x @@ -37,6 +38,7 @@ def __init__(self, is_checkpoint_bf16=False, weight_block_size=(2, 2), name_valu self.weight_block_size = weight_block_size self._name_value = name_value self.deepgemm_scale_ue8m0 = False + self.moe_blockwise_gemm_scale_ue8m0 = False def name(self): return self._name_value @@ -57,6 +59,12 @@ class DummyFDConfig: def __init__(self, load_choices="default_v1"): self.load_config = DummyLoadConfig(load_choices) self.model_config = types.SimpleNamespace(enable_cache=False) + self.scheduler_config = types.SimpleNamespace( + enable_moe_scores_elementwise_fuse=False, + splitwise_role="mixed", + max_num_seqs=8, + max_num_batched_tokens=256, + ) class DummyGate(paddle.nn.Layer): @@ -88,9 +96,14 @@ def __init__( self.n_group = 1 self.topk_group = 1 self.routed_scaling_factor = 1.0 + self.routed_scaling_factor_learnable = False self.renormalize = True self.gate_correction_bias = paddle.zeros([num_local_experts], dtype="float32") self.topk_method = "noaux_tc" + self.with_bias = False + self.ep_size = 1 + self.activation = "swiglu" + self.moe_quant_config = types.SimpleNamespace() self.fd_config = DummyFDConfig(load_choices) self.weight_dtype = weight_dtype self.quant_method = DummyQuantMethod(quant_config) @@ -209,10 +222,15 @@ def test_backend_imports_kernel_module(self, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) reloaded = importlib.reload(backend) assert hasattr(reloaded, "fused_moe_kernel_paddle") + # Restore the real module: reload() permanently rebinds module-level names + # (e.g. fused_moe_kernel_bf16) to the fake, and monkeypatch cannot undo that. + # A second reload after monkeypatch restores sys.modules fixes the binding. + monkeypatch.undo() + importlib.reload(backend) def test_triton_weight_only_create_and_apply(self, fake_ops, monkeypatch): quant_config = DummyQuantConfig(is_checkpoint_bf16=False) @@ -321,7 +339,7 @@ def test_wfp8afp8_method_apply_paths(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -395,7 +413,7 @@ def test_wfp8afp8_apply_noaux_and_empty(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) _ = method.apply( @@ -435,7 +453,7 @@ def test_tensorwise_prequant_and_apply(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) @@ -458,7 +476,7 @@ def test_python_op_fused_moe_kernel_paddle(self, fake_ops, monkeypatch): monkeypatch.setitem( sys.modules, "fastdeploy.model_executor.layers.moe.triton_moe_kernels", - types.SimpleNamespace(fused_moe_kernel_paddle=kernel), + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), ) monkeypatch.setattr( paddle.static, @@ -642,6 +660,7 @@ def test_blockwise_process_weights_ue8m0_branch(self, fake_ops, monkeypatch): """Test the quant_weight_ue8m0 branch in BlockWiseFP8MoEMethod.process_weights_after_loading.""" quant_config = DummyQuantConfig(is_checkpoint_bf16=True, weight_block_size=(128, 128)) quant_config.deepgemm_scale_ue8m0 = True + quant_config.moe_blockwise_gemm_scale_ue8m0 = True layer = DummyLayer(quant_config, weight_dtype="bfloat16") method = backend.BlockWiseFP8MoEMethod(quant_config) method.create_weights(layer, model_format="torch") @@ -695,3 +714,783 @@ def fake_transform_scale_ue8m0(sf, mn, weight_block_size=None): # Verify the quant_weight_ue8m0 branch was executed assert len(quant_calls) > 0, "quant_weight_ue8m0 should have been called" assert len(transform_calls) > 0, "transform_scale_ue8m0 should have been called" + + def test_triton_weight_only_apply_noaux_tc_with_fd_enable_rl(self, fake_ops, monkeypatch): + quant_config = DummyQuantConfig(is_checkpoint_bf16=False) + layer = DummyLayer(quant_config) + layer.topk_method = "noaux_tc" + method = backend.TritonWeightOnlyMoEMethod(quant_config) + method.create_weights(layer, model_format="torch") + + layer._up_weights = [ + paddle.arange(layer.hidden_size * layer.moe_intermediate_size * 2, dtype="float32").reshape( + [layer.hidden_size, layer.moe_intermediate_size * 2] + ) + for _ in range(layer.num_local_experts) + ] + layer._down_weights = [ + paddle.arange(layer.moe_intermediate_size * layer.hidden_size, dtype="float32").reshape( + [layer.moe_intermediate_size, layer.hidden_size] + ) + for _ in range(layer.num_local_experts) + ] + method.process_loaded_weights(layer, state_dict={}) + + kernel = DummyKernel() + monkeypatch.setattr(backend, "fused_moe_kernel_paddle", kernel, raising=False) + + # Set FD_ENABLE_RL=True to trigger use_fused = False at line 313 + # This should trigger gate_out.cast('float32') at line 315 + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + x = paddle.randn([1, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + _ = method.apply(layer, x, gate, topk_ids_hookfunc=hook) + assert "topk_ids" in captured + + def test_python_op_learnable_scaling(self, fake_ops, monkeypatch): + """routed_scaling_factor_learnable=True: per_expert_scale applied to topk_weights inside python_op.""" + quant_config = DummyQuantConfig(is_checkpoint_bf16=False, weight_block_size=(2, 2)) + layer = DummyLayer(quant_config) + layer.routed_scaling_factor_learnable = True + layer.per_expert_scale = paddle.ones([layer.num_local_experts], dtype="float32") + + kernel = DummyKernel() + monkeypatch.setitem( + sys.modules, + "fastdeploy.model_executor.layers.moe.triton_moe_kernels", + types.SimpleNamespace(fused_moe_kernel_paddle=kernel, fused_moe_kernel_bf16=kernel), + ) + monkeypatch.setattr( + paddle.static, + "MetaTensor", + lambda shape, dtype: types.SimpleNamespace(shape=shape, dtype=dtype), + raising=False, + ) + + x = paddle.randn([2, layer.hidden_size], dtype="float32") + gate = DummyGate(layer.num_local_experts) + gate_out = gate(x) + + up_weight = paddle.randn( + [layer.num_local_experts, layer.moe_intermediate_size * 2, layer.hidden_size], dtype="float32" + ) + down_weight = paddle.randn( + [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size], dtype="float32" + ) + up_scale = paddle.ones([layer.num_local_experts, 2, 2], dtype="float32") + down_scale = paddle.ones([layer.num_local_experts, 2, 2], dtype="float32") + + captured = {} + + def hook(topk_ids): + captured["topk"] = topk_ids + + config = {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1} + + _ = backend.python_op_fused_moe_kernel_paddle( + x, + up_weight, + up_scale, + down_weight, + down_scale, + gate_out, + layer.gate_correction_bias, + layer.top_k, + up_weight.shape[1], + down_weight.shape[1], + layer.num_local_experts, + layer.moe_intermediate_size, + layer.hidden_size, + config, + quant_config, + hook, + ) + + assert "topk" in captured + + +class DummyBF16Kernel: + """ + Simulates fused_moe_kernel_bf16[grid](...). + Writes zeros into the output tensor (3rd positional argument). + """ + + def __init__(self): + self.calls = [] + + def __getitem__(self, grid): + def _runner(*args, **kwargs): + # output tensor is the 3rd positional argument (index 2) + if len(args) > 2 and isinstance(args[2], paddle.Tensor): + args[2].set_value(paddle.zeros_like(args[2])) + self.calls.append({"grid": grid, "kwargs": kwargs}) + + return _runner + + +class DummyTL: + """Minimal stub for triton.language so tests don't need a real Triton install.""" + + bfloat16 = "bfloat16" + float16 = "float16" + + +class TestTritonMoEMethod: + """Unit tests for TritonMoEMethod. + + Pattern mirrors TestFusedMoeTritonBackend: + - DummyLayer / DummyGate / DummyFDConfig (reused from module top) + - fake_ops fixture patches routing + preprocess ops + - DummyBF16Kernel patches fused_moe_kernel_bf16 + - No real GPU kernels are executed; output shapes / attributes are verified + """ + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def _make_layer(self, num_experts=2, hidden_size=8, intermediate_size=4, top_k=2): + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + return layer + + def _create_weights(self, method, layer): + """Call create_weights with the mandatory kwargs that the real MoE layer supplies. + + TritonMoEMethod targets the CUDA non-torch weight layout: + up_gate_proj_weight: [E, hidden_size, inter*2] (K-major) + down_proj_weight: [E, inter, hidden_size] (K-major) + Therefore we must NOT pass model_format="torch"; any non-"torch" value + (or omitting the key) lets UnquantizedFusedMoEMethod take the CUDA branch. + """ + method.create_weights( + layer, + model_format="default", + num_experts=layer.num_local_experts, + hidden_size=layer.hidden_size, + moe_intermediate_size=layer.moe_intermediate_size, + ) + + def _patch_bf16_kernel(self, monkeypatch): + kernel = DummyBF16Kernel() + monkeypatch.setattr(backend, "fused_moe_kernel_bf16", kernel, raising=False) + # Patch tl so that `compute_type=tl.bfloat16` inside apply() does not + # raise NameError when triton is not installed in the test environment. + monkeypatch.setattr(backend, "tl", DummyTL(), raising=False) + return kernel + + # ------------------------------------------------------------------ + # __init__ / basic construction + # ------------------------------------------------------------------ + + def test_init_sets_weight_attrs(self): + """TritonMoEMethod.__init__ must expose the two weight attr names.""" + method = backend.TritonMoEMethod() + assert "up_gate_proj_weight" in method.added_weight_attrs + assert "down_proj_weight" in method.added_weight_attrs + + def test_init_none_quant_config(self): + method = backend.TritonMoEMethod(quant_config=None) + assert method.quant_config is None + + # ------------------------------------------------------------------ + # create_weights + # ------------------------------------------------------------------ + + def test_create_weights_registers_parameters(self): + """After create_weights the layer should have up_gate_proj_weight and down_proj_weight.""" + method = backend.TritonMoEMethod() + layer = self._make_layer() + self._create_weights(method, layer) + assert hasattr(layer, "up_gate_proj_weight") + assert hasattr(layer, "down_proj_weight") + + def test_create_weights_shapes(self): + """Weight tensors must have the correct [E, K, N] / [E, N, K] layout.""" + E, H, N = 3, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + # up_gate: [E, hidden_size, intermediate*2] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + # down: [E, intermediate, hidden_size] + assert list(layer.down_proj_weight.shape) == [E, N, H] + + # ------------------------------------------------------------------ + # process_loaded_weights + # ------------------------------------------------------------------ + + def test_process_loaded_weights_stacks_experts(self): + """process_loaded_weights must stack per-expert tensors into the stacked param.""" + E, H, N = 2, 8, 4 + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=E, hidden_size=H, intermediate_size=N) + self._create_weights(method, layer) + + # Provide per-expert tensors via extract_moe_ffn_weights + up_weights = [paddle.ones([H, N * 2], dtype="bfloat16") * (i + 1) for i in range(E)] + down_weights = [paddle.ones([N, H], dtype="bfloat16") * (i + 1) for i in range(E)] + layer._up_weights = up_weights + layer._down_weights = down_weights + + method.process_loaded_weights(layer, state_dict={}) + + # After stacking, shape should be [E, ...] + assert list(layer.up_gate_proj_weight.shape) == [E, H, N * 2] + assert list(layer.down_proj_weight.shape) == [E, N, H] + # Verify each expert's data is correctly stacked (expert i has value i+1) + for i in range(E): + expected_up = float(i + 1) + expected_down = float(i + 1) + actual_up = float(layer.up_gate_proj_weight[i].cast("float32").mean()) + actual_down = float(layer.down_proj_weight[i].cast("float32").mean()) + assert ( + abs(actual_up - expected_up) < 1e-3 + ), f"Expert {i} up_gate weight mean={actual_up}, expected {expected_up}" + assert ( + abs(actual_down - expected_down) < 1e-3 + ), f"Expert {i} down_proj weight mean={actual_down}, expected {expected_down}" + + # ------------------------------------------------------------------ + # ------------------------------------------------------------------ + # _get_default_config — tile heuristic + # ------------------------------------------------------------------ + + def test_get_default_config_decode(self): + """M<=32 decode path → 16x64x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=4, E=8) + assert cfg["BLOCK_SIZE_M"] == 16 + assert cfg["BLOCK_SIZE_N"] == 64 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_mid(self): + """96 < M <= 512 mid path → 64x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=128, E=8) + assert cfg["BLOCK_SIZE_M"] == 64 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_prefill(self): + """M > 512 prefill path → 128x128x64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=1024, E=8) + assert cfg["BLOCK_SIZE_M"] == 128 + assert cfg["BLOCK_SIZE_N"] == 128 + assert cfg["BLOCK_SIZE_K"] == 64 + + def test_get_default_config_boundary_32(self): + """M==32 is decode (<=32).""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=32, E=8) + assert cfg["BLOCK_SIZE_M"] == 16 + + def test_get_default_config_boundary_96(self): + """M==96 is small-mid (32 < M <= 96) → BLOCK_SIZE_M=32.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=96, E=8) + assert cfg["BLOCK_SIZE_M"] == 32 + + def test_get_default_config_boundary_512(self): + """M==512 is mid (<=512) → BLOCK_SIZE_M=64.""" + method = backend.TritonMoEMethod() + cfg = method._get_default_config(M=512, E=8) + assert cfg["BLOCK_SIZE_M"] == 64 + + def test_get_default_config_has_group_size_m(self): + """All configs must include GROUP_SIZE_M key.""" + method = backend.TritonMoEMethod() + for M in (1, 64, 1024): + cfg = method._get_default_config(M=M, E=8) + assert "GROUP_SIZE_M" in cfg + + def test_get_default_config_block_n_boundary(self): + """M<=64 → BLOCK_SIZE_N=64; M>64 → BLOCK_SIZE_N=128.""" + method = backend.TritonMoEMethod() + cfg64 = method._get_default_config(M=64, E=8) + assert cfg64["BLOCK_SIZE_N"] == 64 + cfg65 = method._get_default_config(M=65, E=8) + assert cfg65["BLOCK_SIZE_N"] == 128 + + def test_get_default_config_group_m_16(self): + """tokens_per_expert > 128 → GROUP_SIZE_M=16.""" + method = backend.TritonMoEMethod() + # M=1024, E=1 → tokens_per_expert=1024 > 128 → group_m=16 + cfg = method._get_default_config(M=1024, E=1) + assert cfg["GROUP_SIZE_M"] == 16 + + def test_get_default_config_group_m_1(self): + """tokens_per_expert <= 128 → GROUP_SIZE_M=1.""" + method = backend.TritonMoEMethod() + # M=128, E=8 → tokens_per_expert=16 <= 128 → group_m=1 + cfg = method._get_default_config(M=128, E=8) + assert cfg["GROUP_SIZE_M"] == 1 + + def test_get_default_config_num_warps(self): + """M<=128 → num_warps=4; M>128 → num_warps=8.""" + method = backend.TritonMoEMethod() + cfg128 = method._get_default_config(M=128, E=8) + assert cfg128["num_warps"] == 4 + cfg256 = method._get_default_config(M=256, E=8) + assert cfg256["num_warps"] == 8 + + def test_get_default_config_num_stages(self): + """M<=32 → num_stages=4; M>32 → num_stages=3.""" + method = backend.TritonMoEMethod() + cfg32 = method._get_default_config(M=32, E=8) + assert cfg32["num_stages"] == 4 + cfg33 = method._get_default_config(M=33, E=8) + assert cfg33["num_stages"] == 3 + + # ------------------------------------------------------------------ + # apply — empty-batch fast path + # ------------------------------------------------------------------ + + def test_apply_empty_batch_returns_zero_tensor(self, fake_ops, monkeypatch): + """apply() with 0 tokens must return a zero tensor of shape [0, hidden_size].""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.zeros([0, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [0, layer.hidden_size] + + # ------------------------------------------------------------------ + # apply — normal forward (noaux_tc routing path) + # ------------------------------------------------------------------ + + def test_apply_noaux_tc_output_shape(self, fake_ops, monkeypatch): + """apply() noaux_tc path: output shape must be [token_num, hidden_size].""" + T, H = 4, 8 + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=H) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([T, H], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [T, H] + + def test_apply_noaux_tc_topk_hook_called(self, fake_ops, monkeypatch): + """topk_ids_hookfunc must be called with topk_ids kwarg during apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["topk_ids"] = topk_ids + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert "topk_ids" in captured + + def test_apply_noaux_tc_kernel_called_twice(self, fake_ops, monkeypatch): + """fused_moe_kernel_bf16 must be launched twice (GEMM1 + GEMM2) per forward pass.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + method.apply(layer, x, DummyGate(layer.num_local_experts)) + + assert len(kernel.calls) == 2, f"Expected 2 kernel launches (GEMM1 + GEMM2), got {len(kernel.calls)}" + + # ------------------------------------------------------------------ + # apply — non-noaux routing path (moe_topk_select) + # ------------------------------------------------------------------ + + def test_apply_aux_routing_path(self, fake_ops, monkeypatch): + """When topk_method != 'noaux_tc', the moe_topk_select path is used.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_method = "aux" + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + captured = {} + + def hook(topk_ids): + captured["ids"] = topk_ids + + x = paddle.randn([3, layer.hidden_size], dtype="bfloat16") + out = method.apply(layer, x, DummyGate(layer.num_local_experts), topk_ids_hookfunc=hook) + + assert list(out.shape) == [3, layer.hidden_size] + assert "ids" in captured + + # ------------------------------------------------------------------ + # apply_tp delegates to apply + # ------------------------------------------------------------------ + + def test_apply_tp_delegates_to_apply(self, fake_ops, monkeypatch): + """apply_tp() must produce the same output shape as apply().""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply_tp(layer, x, gate) + + assert list(out.shape) == [2, layer.hidden_size] + + # ------------------------------------------------------------------ + # EP methods raise NotImplementedError + # ------------------------------------------------------------------ + + def test_apply_ep_prefill_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_prefill(layer, None, None) + + def test_apply_ep_decode_raises(self): + method = backend.TritonMoEMethod() + layer = self._make_layer() + with pytest.raises(NotImplementedError): + method.apply_ep_decode(layer, None, None) + + # ------------------------------------------------------------------ + # apply — kernel argument verification + # ------------------------------------------------------------------ + + def test_apply_kernel_even_ks_true(self, fake_ops, monkeypatch): + """When hidden_size is divisible by BLOCK_SIZE_K, even_Ks=True in GEMM1.""" + method = backend.TritonMoEMethod() + # hidden_size=64, BLOCK_SIZE_K=64 → even_Ks=True for GEMM1 + layer = self._make_layer(hidden_size=64, intermediate_size=32) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["even_Ks"] is True + + def test_apply_kernel_even_ks_false(self, fake_ops, monkeypatch): + """When hidden_size is NOT divisible by BLOCK_SIZE_K, even_Ks=False in GEMM1.""" + method = backend.TritonMoEMethod() + # hidden_size=8, BLOCK_SIZE_K=64 → even_Ks=False for GEMM1 + layer = self._make_layer(hidden_size=8, intermediate_size=4) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["even_Ks"] is False + + def test_apply_gemm2_top_k_always_1(self, fake_ops, monkeypatch): + """GEMM2 must always be called with top_k=1 (flat token-expert pairs).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8, top_k=4) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["top_k"] == layer.top_k + assert kernel.calls[1]["kwargs"]["top_k"] == 1 + + def test_apply_gemm1_no_mul_weight_gemm2_mul_weight(self, fake_ops, monkeypatch): + """GEMM1 has MUL_ROUTED_WEIGHT=False, GEMM2 has MUL_ROUTED_WEIGHT=True.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert kernel.calls[0]["kwargs"]["MUL_ROUTED_WEIGHT"] is False + assert kernel.calls[1]["kwargs"]["MUL_ROUTED_WEIGHT"] is True + + def test_apply_large_batch_config(self, fake_ops, monkeypatch): + """Large token count picks larger tile config (BLOCK_SIZE_M=128, num_warps=8).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + kernel = self._patch_bf16_kernel(monkeypatch) + + # 1024 tokens → prefill config: BLOCK_SIZE_M=128 + x = paddle.randn([1024, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert len(kernel.calls) == 2 + assert kernel.calls[0]["kwargs"]["BLOCK_SIZE_M"] == 128 + assert kernel.calls[0]["kwargs"]["num_warps"] == 8 + + def test_apply_single_token_output_shape(self, fake_ops, monkeypatch): + """Single token decode scenario.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(num_experts=128, hidden_size=16, intermediate_size=8, top_k=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + x = paddle.randn([1, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + + assert list(out.shape) == [1, layer.hidden_size] + + def test_get_moe_method_triton_branch(self, monkeypatch): + """get_moe_method() returns TritonMoEMethod when FD_MOE_BACKEND='triton' and is_cuda().""" + from fastdeploy.model_executor.layers.moe import moe as moe_module + + monkeypatch.setattr(moe_module, "current_platform", types.SimpleNamespace(is_cuda=lambda: True)) + monkeypatch.setattr(moe_module.envs, "FD_MOE_BACKEND", "triton") + result = moe_module.get_moe_method() + assert isinstance(result, backend.TritonMoEMethod) + + def test_apply_use_fused_false(self, fake_ops, monkeypatch): + """FD_ENABLE_RL=True triggers use_fused=False branch (gate_out.cast('float32')).""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + monkeypatch.setattr(backend.fastdeploy.envs, "FD_ENABLE_RL", True) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + out = method.apply(layer, x, gate) + assert list(out.shape) == [2, layer.hidden_size] + + def test_apply_tp_with_topk_reduce_func(self, fake_ops, monkeypatch): + """topk_reduce_func attribute is passed through to get_moe_scores.""" + method = backend.TritonMoEMethod() + layer = self._make_layer(hidden_size=8) + layer.topk_reduce_func = lambda x: x + self._create_weights(method, layer) + self._patch_bf16_kernel(monkeypatch) + + scores_kwargs = {} + + def tracking_get_moe_scores(*args, **kwargs): + scores_kwargs.update(kwargs) + gate_out = args[0] + token_num = gate_out.shape[0] + top_k = args[3] + topk_ids = paddle.zeros([token_num, top_k], dtype="int64") + topk_weights = paddle.ones([token_num, top_k], dtype="float32") + return gate_out, topk_weights, topk_ids + + monkeypatch.setattr(backend, "get_moe_scores", tracking_get_moe_scores) + + x = paddle.randn([2, layer.hidden_size], dtype="bfloat16") + gate = DummyGate(layer.num_local_experts) + method.apply(layer, x, gate) + + assert "topk_reduce_func" in scores_kwargs + + +# =========================================================================== +# Precision tests: TritonMoEMethod vs. CutlassMoEMethod (BF16) +# =========================================================================== + + +def _make_precision_layer_pair(num_experts, hidden_size, intermediate_size, top_k): + """ + Build a DummyLayer with random BF16 weights and a TritonMoEMethod. + + Weight layout (CUDA non-torch): [E, H, 2N] for up_gate_proj, [E, N, H] for down_proj. + Returns (layer, None, triton_method) for compatibility with existing test signatures. + """ + layer = DummyLayer( + quant_config=None, + num_local_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + top_k=top_k, + weight_dtype="bfloat16", + ) + + triton_method = backend.TritonMoEMethod() + + # Create weight parameters (CUDA non-torch layout) + triton_method.create_weights( + layer, + model_format="default", + num_experts=num_experts, + hidden_size=hidden_size, + moe_intermediate_size=intermediate_size, + ) + + # Fill with Xavier-like random BF16 weights to produce meaningful output magnitudes. + # W1: [E, H, 2N] — scale by 1/sqrt(H) so GEMM1 output ~O(1) + # W2: [E, N, H] — scale by 1/sqrt(N) so GEMM2 output ~O(1) + paddle.seed(42) + w1_scale = 1.0 / (hidden_size**0.5) + w2_scale = 1.0 / (intermediate_size**0.5) + layer.up_gate_proj_weight.set_value((paddle.randn(layer.up_gate_proj_weight.shape) * w1_scale).cast("bfloat16")) + layer.down_proj_weight.set_value((paddle.randn(layer.down_proj_weight.shape) * w2_scale).cast("bfloat16")) + return layer, None, triton_method + + +def _uniform_gate(layer): + """Gate that outputs uniform logits so every expert gets equal probability.""" + + class _Gate(paddle.nn.Layer): + def __init__(self, num_experts): + super().__init__() + self.num_experts = num_experts + + def forward(self, x): + return paddle.ones([x.shape[0], self.num_experts], dtype="float32") + + return _Gate(layer.num_local_experts) + + +# Shapes to exercise: (token_num, hidden_size, intermediate_size, num_experts, top_k) +# Small/medium sizes to keep test runtime reasonable. +_PRECISION_SHAPES = [ + pytest.param(1, 64, 32, 8, 2, id="decode_T1_H64"), + pytest.param(16, 64, 32, 8, 2, id="decode_T16_H64"), + pytest.param(64, 128, 64, 8, 2, id="mid_T64_H128"), + pytest.param(128, 128, 64, 8, 2, id="mid_T128_H128_E8"), + pytest.param(256, 256, 128, 8, 4, id="prefill_T256_H256"), +] + + +@pytest.mark.skipif(not paddle.is_compiled_with_cuda(), reason="requires CUDA") +# @pytest.mark.skipif(not _triton_ops_available(), reason="triton MoE ops not available (custom ops not compiled)") +class TestTritonMoEPrecision: + """ + Precision tests: Triton BF16 path vs. Cutlass BF16 path. + + Both paths are activated in production via the FD_MOE_BACKEND env var + (triton vs cutlass). This test verifies they produce numerically equivalent + results on the same shared BF16 weights and identical inputs. + + All tests run real GPU kernels (no mocking). + Tolerance: atol=1e-2, rtol=1e-2 (both kernels use BF16 arithmetic with + fp32 accumulation; differences come from tile ordering / rounding). + """ + + # Tolerance for comparing two independent BF16 GEMM implementations. + # BF16 has ~7-bit mantissa (eps ~0.008). After GEMM1 + SwiGLU + GEMM2, + # rounding differences accumulate. Use np.allclose style: + # |triton - cutlass| <= ATOL + RTOL * |cutlass| + ATOL = 1e-3 + RTOL = 1e-3 + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_vs_cutlass(self, T, H, N, E, K): + """Triton BF16 MoE output must agree with CUTLASS BF16 MoE output. + + Both paths use the same weight layout, routing logic, and BF16 arithmetic. + Differences should only come from tile ordering / rounding in GEMM. + """ + from fastdeploy.model_executor.layers.moe.fused_moe_cutlass_backend import ( + CutlassMoEMethod, + ) + + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + + # CUTLASS method shares the same weights (already created by _make_precision_layer_pair) + cutlass_method = CutlassMoEMethod(None) + + paddle.seed(0) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + + # Use a deterministic non-uniform gate to ensure consistent routing + # across multiple calls of noaux_tc (avoids tie-breaking ambiguity) + class _DeterministicGate(paddle.nn.Layer): + def __init__(self, num_experts, T): + super().__init__() + self.num_experts = num_experts + paddle.seed(123) + self._scores = paddle.randn([T, num_experts], dtype="float32") * 2.0 + + def forward(self, x): + return self._scores[: x.shape[0]] + + gate = _DeterministicGate(E, T) + + # --- Run Triton path --- + triton_out = triton_method.apply(layer, x, gate).cast("float32").numpy() + + # --- Run CUTLASS path --- + cutlass_out = cutlass_method.apply(layer, x, gate).cast("float32").numpy() + + # np.allclose style: |a - b| <= atol + rtol * |b| + tol = self.ATOL + self.RTOL * np.abs(cutlass_out) + violations = np.abs(triton_out - cutlass_out) > tol + num_violations = int(violations.sum()) + total_elements = triton_out.size + + assert num_violations == 0, ( + f"[T={T},H={H},N={N},E={E},K={K}] " + f"{num_violations}/{total_elements} elements exceed tolerance " + f"(atol={self.ATOL}, rtol={self.RTOL}). " + f"Max abs diff: {float(np.abs(triton_out - cutlass_out).max()):.2e}, " + f"max |cutlass|: {float(np.abs(cutlass_out).max()):.2e}" + ) + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_shape(self, T, H, N, E, K): + """Output shape must always be [T, H] regardless of batch size.""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert list(out.shape) == [T, H], f"Expected [{T}, {H}], got {list(out.shape)}" + + @pytest.mark.parametrize("T,H,N,E,K", _PRECISION_SHAPES) + def test_triton_output_dtype_is_bfloat16(self, T, H, N, E, K): + """Output dtype must match input dtype (bfloat16).""" + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = (paddle.randn([T, H]) * 0.1).cast("bfloat16") + gate = _uniform_gate(layer) + out = triton_method.apply(layer, x, gate) + assert out.dtype == paddle.bfloat16, f"Expected bfloat16, got {out.dtype}" + + def test_zero_input_gives_zero_output(self): + """All-zero input must produce all-zero output.""" + T, H, N, E, K = 8, 64, 32, 8, 2 + layer, _, triton_method = _make_precision_layer_pair(E, H, N, K) + x = paddle.zeros([T, H], dtype="bfloat16") + gate = _uniform_gate(layer) + + out = triton_method.apply(layer, x, gate).cast("float32").numpy() + np.testing.assert_allclose( + out, + np.zeros_like(out), + atol=1e-6, + err_msg="triton: zero input should produce zero output", + ) diff --git a/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py b/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py index 17a393ee11e..f679be08b31 100644 --- a/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py +++ b/tests/layers/test_kv_cache_int8_dynamic_quant_backend.py @@ -92,6 +92,7 @@ def __init__(self): "max_model_len": 2048, "head_dim": 128, "num_hidden_layers": 2, + "enable_mm": False, "causal": True, "start_layer_index": 0, "rope_3d": False, @@ -124,6 +125,8 @@ def __init__(self): "model_type": "main", }, )() + self.enable_mm_runtime = self.model_config.enable_mm + self.enable_rope_3d_runtime = self.model_config.enable_mm class DummyLayer: diff --git a/tests/layers/test_plas_attention.py b/tests/layers/test_plas_attention.py index 663b27dc9ab..e593595fa5a 100644 --- a/tests/layers/test_plas_attention.py +++ b/tests/layers/test_plas_attention.py @@ -12,8 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys + import paddle +tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, tests_dir) + +from e2e.utils.serving_utils import clean_ports + try: from fastdeploy.model_executor.ops.gpu import ( fused_block_mean_and_rope, @@ -338,6 +346,9 @@ def test_plas_attention(self): self.compare_attn(attn_out, qk_gate_topk_idx) def test_server(self): + # Clean ports before starting the test + clean_ports() + if get_cur_cu_seq_len_k is None: return os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN" diff --git a/tests/layers/test_sampler.py b/tests/layers/test_sampler.py index cdc58eb33d1..72475d4ba49 100644 --- a/tests/layers/test_sampler.py +++ b/tests/layers/test_sampler.py @@ -26,8 +26,8 @@ import fastdeploy # noqa: F401 -if not hasattr(paddle, "compat"): - paddle.compat = types.SimpleNamespace(enable_torch_proxy=lambda *args, **kwargs: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None # Optional runtime deps are intentionally stubbed for unit isolation. if "triton" not in sys.modules: @@ -310,6 +310,7 @@ def test_speculative_sampler_basic(monkeypatch): enf_gen_phase_tag=False, verify_strategy="topp", accept_policy="normal", + num_speculative_tokens=1, ), parallel_config=types.SimpleNamespace(prefill_one_step_stop=False), ) @@ -327,7 +328,7 @@ def test_speculative_sampler_basic(monkeypatch): m.top_p_normalized_logprobs_flag = True m.share_inputs = { "seq_lens_this_time": paddle.to_tensor([[1]], dtype="int64"), - "accept_num": paddle.to_tensor([1], dtype="int64"), + "accept_num": paddle.to_tensor([1], dtype="int32"), } gathered = sampler.gather_logprobs(sampler.compute_logprobs(logits, m), 0, paddle.to_tensor([1], dtype="int64")) assert gathered.logprob_token_ids.shape[1] == 1 diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py index ef75fe5d4e8..11247e3fe06 100644 --- a/tests/layers/test_speculative_sampler.py +++ b/tests/layers/test_speculative_sampler.py @@ -97,12 +97,12 @@ def _create_default_sampling_metadata( return fake_sampling_metadata -def _create_fd_config(max_model_len, method=None): +def _create_fd_config(max_model_len, method=None, verify_strategy="topp"): model_config: Mock = Mock() model_config.max_model_len = max_model_len model_config.architectures = ["test_model"] model_config.mm_max_tokens_per_item = None - speculative_config = SpeculativeConfig({"method": method} if method else {}) + speculative_config = SpeculativeConfig({"method": method, "verify_strategy": verify_strategy}) graph_opt_config = GraphOptimizationConfig({}) scheduler_config = SchedulerConfig({}) parallel_config = ParallelConfig({}) @@ -187,7 +187,7 @@ def test_speculative_sampler(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len) logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size) share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) @@ -196,7 +196,15 @@ def test_speculative_sampler(): increment_value = (max_draft_token_num + 1) * 4 sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) + sampler( + logits, + sampling_metadata, + max_model_len, + share_inputs, + token_num_output_cpu, + increment_value, + real_bsz=batch_size, + ) def test_speculative_sampler_logprobs(): @@ -208,7 +216,7 @@ def test_speculative_sampler_logprobs(): max_draft_token_num = 1 # Use ngram method for speculative decoding - fd_config = _create_fd_config(max_model_len, method="ngram") + fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp") share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size) sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0) sampling_metadata.share_inputs = share_inputs @@ -221,7 +229,15 @@ def test_speculative_sampler_logprobs(): for logprobs_mode in logprobs_mode_list: fd_config.model_config.logprobs_mode = logprobs_mode sampler = SpeculativeSampler(fd_config) - sampler(logits, sampling_metadata, max_model_len, share_inputs, token_num_output_cpu, increment_value) + sampler( + logits, + sampling_metadata, + max_model_len, + share_inputs, + token_num_output_cpu, + increment_value, + real_bsz=batch_size, + ) def test_mtp_sampler(): diff --git a/tests/layers/test_triton_sampler.py b/tests/layers/test_triton_sampler.py new file mode 100644 index 00000000000..f4c12e10082 --- /dev/null +++ b/tests/layers/test_triton_sampler.py @@ -0,0 +1,429 @@ +""" +Unit tests for the triton sampling path introduced in commit 16e692f. + +Covers: + - _apply_triton_top_k_top_p / apply_top_k_top_p_triton Python wrapper + - _random_sample / seeded_gumbel_noise Python wrapper + - Sampler.forward_cuda triton branch (FD_SAMPLING_CLASS="triton") + - SpeculativeSampler triton branches +""" + +import sys +import types + +import paddle +import pytest + +import fastdeploy # noqa: F401 + +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None + +# Stub triton for unit isolation (same pattern as test_sampler.py). +if "triton" not in sys.modules: + triton_stub = types.ModuleType("triton") + triton_stub.jit = lambda fn: fn + triton_stub.next_power_of_2 = lambda n: 1 << (n - 1).bit_length() + triton_lang_stub = types.ModuleType("triton.language") + triton_lang_stub.constexpr = int + sys.modules["triton"] = triton_stub + sys.modules["triton.language"] = triton_lang_stub + +from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata + +# Must import after stubs are in place. +from fastdeploy.model_executor.layers.sample.sampler import ( + Sampler, + SpeculativeSampler, + _apply_triton_top_k_top_p, + _random_sample, +) +from fastdeploy.spec_decode import VerifyStrategy + +# --------------------------------------------------------------------------- +# Fixtures & helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _patch_gpu_deps(monkeypatch): + """Patch only GPU-specific calls so Python wrapper code can execute on CPU.""" + import fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton as triton_mod + + # Patch the kernel launch inside apply_top_k_top_p_triton: replace + # _topk_topp_kernel so it becomes a no-op (logits left unchanged → + # equivalent to "keep all" when no real GPU masking happens). + # This lets the Python wrapper (lines 830-936) run for coverage. + def _fake_kernel_call(grid, kwargs): + pass + + monkeypatch.setattr(triton_mod._topk_topp_kernel, "__call__", _fake_kernel_call) + + # Patch paddle.device.cuda.get_device_properties used inside + # apply_top_k_top_p_triton to avoid "no CUDA device" error. + fake_props = types.SimpleNamespace(multi_processor_count=1) + monkeypatch.setattr( + paddle.device.cuda, + "get_device_properties", + lambda idx: fake_props, + ) + + # Patch _seeded_gumbel_kernel similarly so seeded_gumbel_noise (lines + # 960-981) runs its Python logic without real GPU. + def _fake_gumbel_kernel_call(grid, kwargs): + pass + + monkeypatch.setattr(triton_mod._seeded_gumbel_kernel, "__call__", _fake_gumbel_kernel_call) + + # Patch batched_count_greater_than (used in gather_logprobs). + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.batched_count_greater_than", + lambda x, y: (x >= y).sum(-1), + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.logprobs.batched_count_greater_than", + lambda x, y: (x >= y).sum(-1), + ) + + # Patch current_platform so is_cuda() returns True (needed for + # build_sampling_params import). + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.current_platform.is_cuda", + lambda: True, + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.current_platform.is_xpu", + lambda: False, + ) + + +@pytest.fixture +def mock_ops(monkeypatch): + """Patch heavy GPU ops that are not the focus of triton tests.""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.apply_penalty_multi_scores", + lambda *a, **k: a[1], + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.apply_speculative_penalty_multi_scores", + lambda *a, **k: a[2], + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.min_p_sampling", + lambda probs, *a, **k: probs, + ) + return monkeypatch + + +@pytest.fixture +def triton_mode(monkeypatch): + """Set FD_SAMPLING_CLASS to triton for the duration of the test.""" + import fastdeploy.envs as envs + + monkeypatch.setattr(envs, "FD_SAMPLING_CLASS", "triton") + + +def _create_metadata(batch_size=1, min_seq_len=1, max_seq_len=3, max_num_logprobs=None, **overrides): + m = SamplingMetadata( + temperature=paddle.full(shape=[batch_size, 1], fill_value=0.9, dtype="float32"), + top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"), + prompt_lens=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + token_ids_all=paddle.full(shape=[batch_size, max_seq_len], fill_value=-1, dtype="int64"), + frequency_penalties=paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32"), + presence_penalties=paddle.full(shape=[batch_size, 1], fill_value=0.0, dtype="float32"), + repetition_penalties=paddle.full(shape=[batch_size, 1], fill_value=1.0, dtype="float32"), + min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"), + bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"), + bad_words_token_len=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"), + eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"), + min_p=paddle.zeros([batch_size], dtype="float32"), + seed=paddle.full([batch_size, 1], 7, dtype="int64"), + logits_processors=None, + ) + m.max_num_logprobs = max_num_logprobs + m.top_k = paddle.full([batch_size, 1], 5, dtype="int64") + m.top_k_list = [5 for _ in range(batch_size)] + m.min_p_list = [0.0 for _ in range(batch_size)] + m.enable_early_stop = True + m.stop_flags = paddle.zeros([batch_size, 1], dtype="int32") + m.share_inputs = { + "seq_lens_this_time": paddle.ones([batch_size, 1], dtype="int64"), + "seq_lens_encoder": paddle.zeros([batch_size, 1], dtype="int64"), + "seq_lens_decoder": paddle.zeros([batch_size, 1], dtype="int64"), + } + for k, v in overrides.items(): + setattr(m, k, v) + return m + + +def _make_stubbed_sampler(mode="processed_logprobs"): + s = Sampler.__new__(Sampler) + s.guided_decoding = types.SimpleNamespace(apply_token_mask=lambda logits, p_done_idxs: logits) + s.logprobs_mode = mode + s.early_stopper = types.SimpleNamespace(process=lambda probs, next_tokens, stop_flags: None) + return s + + +# --------------------------------------------------------------------------- +# Tests for _apply_triton_top_k_top_p (direct call) +# --------------------------------------------------------------------------- + + +class TestApplyTritonTopKTopP: + """Tests for _apply_triton_top_k_top_p.""" + + def test_returns_logits_unchanged_when_both_none(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=None, top_k=None) + assert paddle.equal_all(result, logits) + + def test_top_p_only_no_error(self): + """top_p filtering runs through apply_top_k_top_p_triton wrapper.""" + logits = paddle.to_tensor([[1.0, 2.0, 5.0]], dtype="float32") + top_p = paddle.to_tensor([[0.7]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p) + assert result.shape == [1, 3] + + def test_top_k_disabled_when_list_none(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[1.0]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, top_k=None, top_k_list=None) + assert result.shape == [1, 3] + + def test_return_mask_false(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[0.9]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, return_mask=False) + assert isinstance(result, paddle.Tensor) + + def test_return_mask_true(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + top_p = paddle.to_tensor([[0.5]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p, return_mask=True) + assert isinstance(result, tuple) + assert len(result) == 2 + logits_out, mask = result + assert logits_out.shape == [1, 3] + assert mask.shape == [1, 3] + assert mask.dtype == paddle.bool + + def test_output_dtype_is_float32(self): + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float16") + top_p = paddle.to_tensor([[0.9]], dtype="float32") + result = _apply_triton_top_k_top_p(logits, top_p=top_p) + assert result.dtype == paddle.float32 + + def test_combined_top_k_top_p(self): + logits = paddle.to_tensor([[1.0, 5.0, 3.0, 2.0, 4.0]], dtype="float32") + top_p = paddle.to_tensor([[0.5]], dtype="float32") + top_k = paddle.to_tensor([[3]], dtype="int64") + top_k_list = [3] + result = _apply_triton_top_k_top_p(logits, top_p=top_p, top_k=top_k, top_k_list=top_k_list) + assert result.shape == [1, 5] + + +# --------------------------------------------------------------------------- +# Tests for _random_sample (direct call) +# --------------------------------------------------------------------------- + + +class TestRandomSample: + """Tests for _random_sample.""" + + def test_output_shape_and_dtype(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], dtype="float32") + result = _random_sample(probs) + assert result.shape == [2, 1] + assert result.dtype == paddle.int64 + + def test_without_seed(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7]], dtype="float32") + result = _random_sample(probs, topp_seed=None) + assert 0 <= result[0, 0].item() < 3 + + def test_with_seed(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7]], dtype="float32") + seed = paddle.to_tensor([[42]], dtype="int64") + result = _random_sample(probs, topp_seed=seed) + assert result.shape == [1, 1] + + def test_greedy_with_peak_distribution(self): + probs = paddle.zeros([1, 10], dtype="float32") + probs[0, 5] = 1.0 + result = _random_sample(probs) + assert result[0, 0].item() == 5 + + def test_batch_multiple_requests(self): + probs = paddle.to_tensor([[0.1, 0.2, 0.7], [0.0, 0.0, 1.0]], dtype="float32") + result = _random_sample(probs) + assert result.shape == [2, 1] + assert 0 <= result[0, 0].item() < 3 + assert result[1, 0].item() == 2 + + +# --------------------------------------------------------------------------- +# Tests for Sampler.forward_cuda with triton path +# --------------------------------------------------------------------------- + + +class TestSamplerTritonPath: + """Test Sampler.forward_cuda with FD_SAMPLING_CLASS=triton.""" + + def test_forward_cuda_triton_path(self, mock_ops, triton_mode): + """Sampler.forward_cuda should call _apply_triton_top_k_top_p and _random_sample.""" + sampler = _make_stubbed_sampler("processed_logprobs") + m = _create_metadata(batch_size=1, max_num_logprobs=2) + + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + output = sampler.forward_cuda(logits, m) + assert output.sampled_token_ids.shape == [1, 1] + assert output.logprobs_tensors is not None + + +# --------------------------------------------------------------------------- +# Tests for SpeculativeSampler triton branches +# --------------------------------------------------------------------------- + + +def _make_spec_sampler(verify_strategy=VerifyStrategy.TARGET_MATCH, spec_method=None): + """Create a SpeculativeSampler with stubbed internals.""" + s = SpeculativeSampler.__new__(SpeculativeSampler) + s.verify_strategy = verify_strategy + s.spec_method = spec_method # None → NAIVE path + s.enf_gen_phase_tag = False + s.config_accept_all = False + s.config_reject_all = False + s.speculative_benchmark_mode = False + s.speculative_max_candidate_len = 1 + s.speculative_verify_window = 2 + s.think_end_id = 1 + s.line_break_id = 2 + s.logprobs_mode = "processed_logprobs" + return s + + +def _spec_share_inputs(batch_size=1): + return { + "seq_lens_this_time": paddle.ones([batch_size, 1], dtype="int64"), + "seq_lens_encoder": paddle.zeros([batch_size, 1], dtype="int64"), + "cu_seqlens_q_output": paddle.to_tensor([0] + [1] * batch_size, dtype="int32"), + "batch_id_per_token_output": paddle.zeros([batch_size], dtype="int32"), + "accept_tokens": paddle.zeros([batch_size, 1], dtype="int64"), + "accept_num": paddle.zeros([batch_size], dtype="int32"), + "draft_tokens": paddle.zeros([batch_size, 1], dtype="int64"), + "stop_flags": paddle.zeros([batch_size, 1], dtype="int32"), + "is_block_step": paddle.zeros([batch_size], dtype="int32"), + "reasoning_status": paddle.zeros([batch_size, 1], dtype="int32"), + "max_dec_len": paddle.full([batch_size, 1], 1024, dtype="int64"), + "step_idx": paddle.zeros([batch_size, 1], dtype="int64"), + } + + +class TestSpeculativeSamplerTritonPath: + """Test SpeculativeSampler triton branches (lines 916, 1016-1017, 1120-1132).""" + + def test_verify_and_sample_target_match_triton(self, mock_ops, triton_mode, monkeypatch): + """_verify_and_sample with TARGET_MATCH + triton → calls _random_sample (line 916).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.build_sampling_params", + lambda *a, **k: ( + paddle.to_tensor([[0.9]], dtype="float32"), + paddle.to_tensor([[5]], dtype="int64"), + paddle.to_tensor([[7]], dtype="int64"), + ), + ) + # verify_draft_tokens is lazily imported inside _verify_and_sample + import fastdeploy.model_executor.ops.gpu as gpu_ops + + monkeypatch.setattr(gpu_ops, "verify_draft_tokens", lambda *a, **k: None) + monkeypatch.setattr(gpu_ops, "top_p_candidates", lambda *a, **k: (None, None, None)) + + sampler = _make_spec_sampler(verify_strategy=VerifyStrategy.TARGET_MATCH, spec_method="ngram") + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + probs = paddle.nn.functional.softmax(logits, axis=-1) + seeds = paddle.ones([probs.shape[0], 1], dtype="int64") + + out = sampler._verify_and_sample( + logits, + probs, + m, + max_model_len=8, + share_inputs=_spec_share_inputs(), + token_num_output_cpu=1, + increment_value=1, + topp_seed=seeds, + ) + assert out.sampled_token_ids is not None + + def test_normal_sample_triton(self, mock_ops, triton_mode, monkeypatch): + """_normal_sample with triton → calls _random_sample (line 1016-1017).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.naive_update_model_status", + lambda *a, **k: None, + ) + + sampler = _make_spec_sampler(spec_method=None) # None → NAIVE + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + probs = paddle.nn.functional.softmax(logits, axis=-1) + seeds = paddle.ones([probs.shape[0], 1], dtype="int64") + + out = sampler._normal_sample(logits, probs, m, share_inputs=_spec_share_inputs(), topp_seed=seeds) + assert out.sampled_token_ids is not None + + def test_forward_cuda_triton_logit_mask(self, mock_ops, triton_mode, monkeypatch): + """SpeculativeSampler.forward_cuda with triton → masks logits (lines 1120-1132).""" + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.build_sampling_params", + lambda *a, **k: ( + paddle.to_tensor([[0.9]], dtype="float32"), + paddle.to_tensor([[5]], dtype="int64"), + paddle.to_tensor([[7]], dtype="int64"), + ), + ) + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.sampler.naive_update_model_status", + lambda *a, **k: None, + ) + + sampler = _make_spec_sampler(spec_method=None) # NAIVE → _normal_sample + m = _create_metadata(batch_size=1) + logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32") + + out = sampler.forward_cuda( + logits, + m, + max_model_len=8, + share_inputs=_spec_share_inputs(), + token_num_output_cpu=1, + increment_value=1, + ) + assert out.sampled_token_ids is not None + + +# --------------------------------------------------------------------------- +# Tests for triton Python wrapper functions (top_k_top_p_triton.py coverage) +# --------------------------------------------------------------------------- + + +class TestTritonWrapperFunctions: + """Cover the Python wrapper functions in top_k_top_p_triton.py.""" + + def test_reset_buffer_cache(self, monkeypatch): + """reset_buffer_cache should run without error.""" + from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + reset_buffer_cache, + ) + + monkeypatch.setattr( + "fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton.paddle.accelerator", + types.SimpleNamespace(empty_cache=lambda: None), + raising=False, + ) + reset_buffer_cache() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/layers/test_trtllm_allreduce_rms_fusion.py b/tests/layers/test_trtllm_allreduce_rms_fusion.py new file mode 100644 index 00000000000..8edd007cadd --- /dev/null +++ b/tests/layers/test_trtllm_allreduce_rms_fusion.py @@ -0,0 +1,54 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import subprocess +import sys + + +def test_run_distributed(): + """Launch multi-GPU distributed test via paddle.distributed.launch as subprocess""" + # flashinfer_cache_dir = os.path.expanduser("~/.cache/flashinfer") + # if os.path.exists(flashinfer_cache_dir): + # print(f"=== Clearing flashinfer cache directory: {flashinfer_cache_dir} ===") + # subprocess.run(["rm", "-rf", flashinfer_cache_dir], check=True) + current_dir = os.path.dirname(os.path.abspath(__file__)) + run_script = os.path.join(current_dir, "trtllm_allreduce_rms_fusion.py") + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + command = [ + sys.executable, + "-m", + "paddle.distributed.launch", + "--gpus", + "0,1", + run_script, + ] + + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + try: + stdout, stderr = process.communicate(timeout=400) + return_code = process.returncode + except subprocess.TimeoutExpired: + process.kill() + stdout, stderr = process.communicate() + return_code = -1 + print(f"=== Distributed test stdout ===\n{stdout}") + print(f"=== Distributed test stderr ===\n{stderr}") + assert return_code in (0, 250), f"Process exited with code {return_code}" + + +test_run_distributed() diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py new file mode 100644 index 00000000000..117e2edbe32 --- /dev/null +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -0,0 +1,845 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import time +import unittest +from unittest.mock import Mock, patch + +import numpy as np +import paddle +import paddle.distributed as dist + + +class TestFlashInferAllReduceResidualRMSNorm(unittest.TestCase): + """Test FlashInfer AllReduce + Residual + RMSNorm fused operator""" + + @classmethod + def setUpClass(cls): + """Set up test environment""" + if paddle.is_compiled_with_cuda(): + # Bind each rank to its own GPU explicitly; otherwise all ranks + # default to "gpu:0" and cudaIpcOpenMemHandle fails with + # "invalid device context". + local_rank = int( + os.environ.get("PADDLE_LOCAL_RANK", os.environ.get("FLAGS_selected_gpus", "0").split(",")[0]) + ) + paddle.set_device(f"gpu:{local_rank}") + + # paddle.distributed.launch remaps each rank's visible GPU to + # index 0 inside the worker process. flashinfer's IPC calls go + # through the cudart runtime API (cuda-python), which maintains + # its own primary context separate from Paddle's driver context. + # Explicitly activate cudart's primary context on device 0 here, + # otherwise cudaIpcOpenMemHandle reports "invalid device context". + try: + from cuda import cudart + + cudart.cudaSetDevice(0) + cudart.cudaFree(0) # force primary context creation + except ImportError: + pass + else: + paddle.set_device("cpu") + dist.init_parallel_env() + if paddle.is_compiled_with_cuda(): + # Force the CUDA primary context to be created on the current + # device before flashinfer's cudart IPC calls run. + paddle.zeros([1]).cuda() + paddle.device.cuda.synchronize() + + def setUp(self): + """Initialize each test case""" + # Fix random seed for reproducibility + paddle.seed(42) + np.random.seed(42) + + # NOTE: switched fp32 -> bf16 to mirror real model dtype on B GPUs. + # Combined with use_oneshot=None below, this exercises the bf16 + + # oneshot Lamport path, which is the suspected garbled-output path + # on Blackwell (sm100). + self.dtype = paddle.bfloat16 + self.token_num = 128 + self.hidden_dim = 4096 + self.eps = 1e-6 + self.epsilon = 1e-6 + self.max_token_num = 2048 + + # Create mock FDConfig + self.fd_config = Mock() + self.fd_config.parallel_config = Mock() + self.fd_config.parallel_config.tensor_parallel_size = dist.get_world_size() + self.begin_norm_axis = 1 + + # Performance test params - increase iterations for stability + self.warmup_iterations = 20 # Increase warmup + self.test_iterations = 200 # Increase test iterations + + def tearDown(self): + """Clean up resources""" + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.empty_cache() + paddle.device.cuda.synchronize() + + def create_test_tensors(self): + """Create test tensors""" + input_tensor = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + residual = paddle.randn([self.token_num, self.hidden_dim], dtype=self.dtype) + weight = paddle.randn([self.hidden_dim], dtype=self.dtype) + return input_tensor, residual, weight + + def compute_reference_output(self, input_tensor, residual, weight, eps): + """Reference implementation: manually compute AllReduce + Residual + RMSNorm""" + # # Step 1: AllReduce (identity on single device) + # allreduce_out = input_tensor.clone() + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + # Step 2: Add residual + residual_out = input_tensor + residual + + # Step 3: RMSNorm + variance = residual_out.pow(2).mean(axis=-1, keepdim=True) + norm_out = residual_out * paddle.rsqrt(variance + eps) + norm_out = norm_out * weight + + # dist.all_reduce(residual_out, op=dist.ReduceOp.SUM) + return norm_out, residual_out + + def paddle_rms_fuse(self, input_tensor, residual, weight, eps): + from paddle.incubate.nn.functional import fused_rms_norm + + # Apply all reduce operator + dist.all_reduce(input_tensor, op=dist.ReduceOp.SUM) + out_fused = fused_rms_norm( + input_tensor, + norm_weight=weight, + norm_bias=None, + epsilon=eps, + begin_norm_axis=self.begin_norm_axis, + bias=None, + residual=residual, + ) + + return out_fused[0], out_fused[1] + + def flashinfer_rms_fuse(self, input_tensor, residual, weight, eps): + """FlashInfer fused operator""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=eps, + max_token_num=self.max_token_num, + # NOTE: do NOT pass use_oneshot=False here. We want the auto path + # (use_oneshot=None) so the oneshot Lamport kernel is exercised, + # matching how normalization.py calls it in the real model. + ) + return norm_out, residual_out + + def benchmark_function(self, func, *args, name="", **kwargs): + """ + Improved performance benchmark + - Wait for GPU frequency stabilization + - Use median instead of mean (more stable) + - Filter outliers + """ + # Force GPU frequency stabilization + if paddle.is_compiled_with_cuda(): + for _ in range(5): + paddle.device.cuda.synchronize() + time.sleep(0.01) + + # Warmup - thorough warm-up + for _ in range(self.warmup_iterations): + result = func(*args, **kwargs) + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + # Extra wait to ensure GPU stability + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + time.sleep(0.1) + + # Benchmark run + times = [] + for i in range(self.test_iterations): + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + start = time.perf_counter() + result = func(*args, **kwargs) + + if paddle.is_compiled_with_cuda(): + paddle.device.cuda.synchronize() + + end = time.perf_counter() + elapsed = (end - start) * 1000 # Convert to milliseconds + times.append(elapsed) + + times = np.array(times) + + # Filter outliers using IQR method + q1, q3 = np.percentile(times, [25, 75]) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + filtered_times = times[(times >= lower_bound) & (times <= upper_bound)] + + # Fall back to raw data if too many samples filtered out + if len(filtered_times) < self.test_iterations * 0.5: + filtered_times = times + + # Statistics + avg_time = np.mean(filtered_times) + median_time = np.median(filtered_times) + std_time = np.std(filtered_times) + min_time = np.min(filtered_times) + max_time = np.max(filtered_times) + cv = (std_time / avg_time) * 100 # Coefficient of variation (%) + + print(f"\n{'='*70}") + print(f"Performance Benchmark: {name}") + print(f"{'='*70}") + print(f"Iterations: {len(filtered_times)}/{self.test_iterations} (after {self.warmup_iterations} warmup)") + print(f"Median: {median_time:.4f} ms (most stable metric)") + print(f"Average: {avg_time:.4f} ms") + print(f"Std Dev: {std_time:.4f} ms (CV: {cv:.2f}%)") + print(f"Min: {min_time:.4f} ms") + print(f"Max: {max_time:.4f} ms") + print(f"{'='*70}\n") + + # Return median (more stable) and result + return median_time, result + + def test_accuracy_fused_vs_reference(self): + """Test accuracy of fused operator vs reference implementation""" + input_tensor, residual, weight = self.create_test_tensors() + reference_output, ref_res = self.compute_reference_output( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + fused_output, paddle_res = self.paddle_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + flashinfer_output, flashinfer_res = self.flashinfer_rms_fuse( + input_tensor.clone(), residual.clone(), weight.clone(), self.eps + ) + + # bf16 needs much looser tolerance than fp32. Cast to fp32 for + # comparison to avoid numpy bf16 issues. + if self.dtype == paddle.bfloat16: + rtol, atol = 5e-2, 5e-2 + to_np = lambda t: t.astype("float32").numpy() # noqa: E731 + else: + rtol, atol = 1e-5, 1e-5 + to_np = lambda t: t.numpy() # noqa: E731 + + # Verify results + np.testing.assert_allclose(to_np(fused_output), to_np(reference_output), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(ref_res), to_np(paddle_res), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(flashinfer_output), to_np(reference_output), rtol=rtol, atol=atol) + np.testing.assert_allclose(to_np(ref_res), to_np(flashinfer_res), rtol=rtol, atol=atol) + + +class TestFlashInferWorkspaceManager(unittest.TestCase): + """Test FlashInferWorkspaceManager""" + + def setUp(self): + """Initialize""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + self.manager = FlashInferWorkspaceManager() + + def test_initialization(self): + """Test initialization state""" + self.assertIsNone(self.manager.workspace_tensor) + self.assertIsNone(self.manager.ipc_handles) + self.assertIsNone(self.manager.world_size) + self.assertIsNone(self.manager.rank) + self.assertFalse(self.manager.initialized) + + def test_cleanup(self): + """Test cleanup functionality""" + self.manager.cleanup() + self.assertFalse(self.manager.initialized) + self.assertIsNone(self.manager.workspace_tensor) + + +class TestFlashInferWorkspaceManagerEdgeCases(unittest.TestCase): + """Test FlashInferWorkspaceManager edge cases and fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + # Patch before importing to test fallback paths + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_initialization_early_return_when_already_initialized(self): + """Test line 47: early return when already initialized with same world_size""" + # Patch _flashinfer_comm to be available + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + + # First initialization + manager.initialized = True + manager.world_size = 2 + + # Mock the comm functions + mock_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion = Mock(return_value=(Mock(), Mock())) + + # Second initialization with same world_size - should return early + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + def test_initialization_warning_when_comm_none(self): + """Test lines 50-51: warning when _flashinfer_comm is None""" + # Patch to ensure _get_flashinfer_comm returns None + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm", + return_value=None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + + # Should not raise, just log warning and return + manager.initialize( + world_size=2, + rank=0, + max_token_num=2048, + hidden_dim=4096, + ) + + # Verify not initialized + self.assertFalse(manager.initialized) + + def test_cleanup_with_exception(self): + """Test lines 73-80: cleanup with exception handling""" + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = True + manager.ipc_handles = Mock() + manager.workspace_tensor = Mock() + + # Mock the destroy function to raise exception + mock_comm.trtllm_destroy_ipc_workspace_for_all_reduce = Mock(side_effect=RuntimeError("Cleanup error")) + + # Should not raise, just log warning + manager.cleanup() + + # Verify cleanup happened + self.assertFalse(manager.initialized) + self.assertIsNone(manager.workspace_tensor) + self.assertIsNone(manager.ipc_handles) + + def test_cleanup_without_initialization(self): + """Test cleanup when not initialized""" + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + FlashInferWorkspaceManager, + ) + + manager = FlashInferWorkspaceManager() + manager.initialized = False + + # Should not raise + manager.cleanup() + + # Verify state + self.assertFalse(manager.initialized) + + +class TestEnsureWorkspaceInitialized(unittest.TestCase): + """Test ensure_workspace_initialized fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_ensure_workspace_when_flashinfer_not_available(self): + """Test line 91: early return when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False (not initialized) + self.assertFalse(result) + + def test_ensure_workspace_when_comm_none(self): + """Test ensure_workspace_initialized when _flashinfer_comm is None""" + self.mock_has_flashinfer.return_value = True + + with patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion._get_flashinfer_comm", + return_value=None, + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + result = ensure_workspace_initialized(fd_config) + + # Should return False + self.assertFalse(result) + + def test_ensure_workspace_single_gpu(self): + """Test line 96: early return when world_size <= 1""" + self.mock_has_flashinfer.return_value = True + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + ensure_workspace_initialized, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.dist.get_rank", return_value=0): + result = ensure_workspace_initialized(fd_config) + + # Should return False for single GPU + self.assertFalse(result) + + +class TestFlashInferAllReduceResidualRMSNormFallbacks(unittest.TestCase): + """Test flashinfer_allreduce_residual_rmsnorm fallback paths""" + + def setUp(self): + """Initialize test fixtures""" + self.patcher_has_flashinfer = patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion.has_flashinfer") + self.mock_has_flashinfer = self.patcher_has_flashinfer.start() + + def tearDown(self): + """Clean up patches""" + self.patcher_has_flashinfer.stop() + + def test_flashinfer_not_available_fallback(self): + """Test lines 140-141: fallback when flashinfer not available""" + self.mock_has_flashinfer.return_value = False + + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None when flashinfer not available + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_single_gpu_fallback(self): + """Test lines 146-147: fallback for single GPU""" + self.mock_has_flashinfer.return_value = True + + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm"): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 1 + + input_tensor = paddle.randn([128, 768]) + residual = paddle.randn([128, 768]) + weight = paddle.randn([768]) + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return None, None for single GPU + self.assertIsNone(norm_out) + self.assertIsNone(residual_out) + + def test_empty_tensor_handling(self): + """Test line 166: empty tensor handling""" + self.mock_has_flashinfer.return_value = True + + with ( + patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._flashinfer_comm") as mock_comm, + patch( + "fastdeploy.model_executor.layers.flashinfer_comm_fusion.ensure_workspace_initialized", + return_value=True, + ), + ): + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = 2 + + # Empty tensor (0 tokens) + input_tensor = paddle.zeros([0, 768]) + residual = paddle.zeros([0, 768]) + weight = paddle.randn([768]) + + # Mock the trtllm_allreduce_fusion to not be called + mock_comm.trtllm_allreduce_fusion = Mock() + + norm_out, residual_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=fd_config, + input_tensor=input_tensor, + residual=residual, + weight=weight, + eps=1e-6, + max_token_num=2048, + ) + + # Should return empty tensors, not call flashinfer + self.assertEqual(norm_out.shape[0], 0) + self.assertEqual(residual_out.shape[0], 0) + mock_comm.trtllm_allreduce_fusion.assert_not_called() + + +class TestCleanupFlashInferWorkspace(unittest.TestCase): + """Test cleanup_flashinfer_workspace function""" + + def test_cleanup_workspace_function(self): + """Test lines 211-212: cleanup function""" + with patch("fastdeploy.model_executor.layers.flashinfer_comm_fusion._workspace_manager") as mock_manager: + from fastdeploy.model_executor.layers.flashinfer_comm_fusion import ( + cleanup_flashinfer_workspace, + ) + + mock_manager.cleanup = Mock() + + cleanup_flashinfer_workspace() + + mock_manager.cleanup.assert_called_once() + + +class TestRMSNormProxyAllreduceFused(unittest.TestCase): + @classmethod + def setUpClass(cls): + # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py + # has already done paddle.set_device + init_parallel_env, so we don't + # repeat that here. (unittest.main runs in the same process.) + cls.tp_size = dist.get_world_size() + cls.tp_rank = dist.get_rank() + + def _make_fd_config(self, enable_fusion: bool): + """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.""" + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = self.tp_size + fd_config.parallel_config.tensor_parallel_rank = self.tp_rank + fd_config.parallel_config.tp_group = dist.get_group() + fd_config.parallel_config.expert_parallel_size = 1 + fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion + fd_config.parallel_config.use_sequence_parallel_moe = False + fd_config.model_config = Mock() + fd_config.model_config.moe_layer_start_index = -1 + fd_config.quant_config = None + return fd_config + + def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1): + """Build a real RMSNorm whose enable_all_reduce_fusion resolves to + `enable_fusion` (use post_attention_layernorm prefix to ensure the + prefix-match in __init__ passes).""" + from fastdeploy.model_executor.layers.normalization import RMSNorm + + fd_config = self._make_fd_config(enable_fusion=enable_fusion) + norm = RMSNorm( + fd_config=fd_config, + hidden_size=hidden_size, + eps=1e-6, + prefix=f"model.layers.{layer_id}.post_attention_layernorm", + layer_id=layer_id, + dtype="bfloat16", + ) + # Initialize weight to a known reproducible value (constant=1.0 by default). + with paddle.no_grad(): + paddle.seed(2024) + new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16) + dist.broadcast(new_w, src=0) + norm.weight.set_value(new_w) + return norm + + @staticmethod + def _proxy_rmsnorm_fn(x, weight, eps): + """Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula + in fp32 to keep reference numerics clean.""" + x_fp32 = x.astype("float32") + var = x_fp32.pow(2).mean(axis=-1, keepdim=True) + out = x_fp32 * paddle.rsqrt(var + eps) + out = out * weight.astype("float32") + return out.astype(x.dtype) + + def _reference(self, x_partial, residual, weight, eps): + """Manual: all_reduce(x_partial) + residual, then standard RMSNorm. + Mirrors what proxy path WOULD produce after explicit allreduce+add.""" + x = x_partial.clone() + dist.all_reduce(x, op=dist.ReduceOp.SUM) + residual_out = x + residual + norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps) + return norm_out, residual_out + + def _make_inputs(self, token_num, hidden_size, seed=123): + """Each rank gets a different x_partial (simulates RowParallelLinear's + un-reduced output); residual is identical across ranks.""" + paddle.seed(seed + self.tp_rank * 7919) + x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1 + paddle.seed(seed + 99) + residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + return x_partial, residual + + def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""): + a32 = a.astype("float32").numpy() + b32 = b.astype("float32").numpy() + np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg) + + # ---------- Tests ---------- + + def test_proxy_path_takes_fused_branch(self): + """fusion=on, tp>1, shape<=2048, residual!=None + -> proxy branch picks flashinfer_allreduce_residual_rmsnorm. + Verify by patching the symbol and asserting it was called. + """ + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + self.assertTrue(norm.enable_all_reduce_fusion) + x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden) + + # Patch within the normalization module's namespace. + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm", + wraps=__import__( + "fastdeploy.model_executor.layers.normalization", fromlist=["flashinfer_allreduce_residual_rmsnorm"] + ).flashinfer_allreduce_residual_rmsnorm, + ) as spy: + out, res = norm.forward( + x_partial.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + spy.assert_called_once() + + # Numerics: must match reference (allreduce + add + std rmsnorm). + ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch") + self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch") + + def test_proxy_path_falls_back_when_fusion_disabled(self): + """fusion=off -> proxy branch must call proxy_rmsnorm directly, + no fused allreduce path used. Input is treated as already-reduced.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden) + self.assertFalse(norm.enable_all_reduce_fusion) + + # Each rank uses the SAME x (already-reduced) — that's the contract + # when fusion is off (RowParallelLinear has done its own allreduce). + paddle.seed(777) + x = paddle.randn([64, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([64, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + proxy_called = {"n": 0} + + def proxy_spy(_x, _w, _eps): + proxy_called["n"] += 1 + return self._proxy_rmsnorm_fn(_x, _w, _eps) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=proxy_spy, + ) + fused_spy.assert_not_called() + + self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once") + + # Reference: x is already full -> just add + rmsnorm, no allreduce. + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch") + self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch") + + def test_proxy_path_falls_back_when_token_too_large(self): + """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused; + in this regime upstream RowParallelLinear didn't skip its own + all-reduce, so x is already full and proxy_rmsnorm is invoked directly.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 256 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + # shape[0] > 2048 forces use_allreduce_fused=False + token_num = 2049 + paddle.seed(555) + x = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + fused_spy.assert_not_called() + + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch") + self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch") + + +class TestGlm4MoeMLPInit(unittest.TestCase): + """Cover Glm4MoeMLP.__init__ attribute assignments (glm4_moe.py:67-71).""" + + def _make_fd_config(self, tp_size, ep_size, enable_fusion, use_seq_parallel_moe=False): + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = tp_size + fd_config.parallel_config.expert_parallel_size = ep_size + fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion + fd_config.parallel_config.use_sequence_parallel_moe = use_seq_parallel_moe + fd_config.model_config = Mock() + fd_config.model_config.hidden_size = 64 + fd_config.model_config.hidden_act = "silu" + fd_config.model_config.moe_layer_start_index = 0 + return fd_config + + def _build(self, fd_config, layer_id=1): + # Patch heavy submodules so Glm4MoeMLP.__init__ runs without real deps. + with ( + patch("fastdeploy.model_executor.models.glm4_moe.MergedColumnParallelLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.MergedReplicatedLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.RowParallelLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.ReplicatedLinear"), + patch("fastdeploy.model_executor.models.glm4_moe.SiluAndMul"), + ): + from fastdeploy.model_executor.models.glm4_moe import Glm4MoeMLP + + return Glm4MoeMLP( + fd_config=fd_config, + intermediate_size=128, + layer_id=layer_id, + prefix="model.layers.1.mlp", + ) + + def test_tp_only_fusion_enabled(self): + """tp>1, ep=1, fusion=True -> enable_all_reduce_fusion=True.""" + fd_config = self._make_fd_config(tp_size=4, ep_size=1, enable_fusion=True) + mlp = self._build(fd_config) + self.assertEqual(mlp.expert_parallel_size, 1) + self.assertEqual(mlp.tensor_parallel_size, 4) + self.assertTrue(mlp.use_tp) + self.assertFalse(mlp.use_ep) + self.assertTrue(mlp.enable_all_reduce_fusion) + + def test_ep_disables_fusion(self): + """ep>1 -> enable_all_reduce_fusion forced False even if flag is True.""" + fd_config = self._make_fd_config(tp_size=2, ep_size=2, enable_fusion=True) + mlp = self._build(fd_config) + self.assertTrue(mlp.use_tp) + self.assertTrue(mlp.use_ep) + self.assertFalse(mlp.enable_all_reduce_fusion) + + def test_single_gpu_no_fusion(self): + """tp=1, ep=1 -> use_tp/use_ep False, fusion False.""" + fd_config = self._make_fd_config(tp_size=1, ep_size=1, enable_fusion=True) + mlp = self._build(fd_config) + self.assertFalse(mlp.use_tp) + self.assertFalse(mlp.use_ep) + self.assertFalse(mlp.enable_all_reduce_fusion) + + def test_fusion_flag_off(self): + """flag False -> enable_all_reduce_fusion False regardless of tp.""" + fd_config = self._make_fd_config(tp_size=4, ep_size=1, enable_fusion=False) + mlp = self._build(fd_config) + self.assertTrue(mlp.use_tp) + self.assertFalse(mlp.enable_all_reduce_fusion) + + +if __name__ == "__main__": + """Run tests directly (called by subprocess after distributed launch)""" + unittest.main(verbosity=2) diff --git a/tests/metrics/test_benchmark_metrics_logger.py b/tests/metrics/test_benchmark_metrics_logger.py new file mode 100644 index 00000000000..4f291327ad0 --- /dev/null +++ b/tests/metrics/test_benchmark_metrics_logger.py @@ -0,0 +1,499 @@ +""" +Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +import os +import time +from unittest.mock import MagicMock, patch + +import pytest + +from fastdeploy.config import BenchmarkMetricsConfig, FDConfig +from fastdeploy.metrics.benchmark_metrics_logger import ( + BenchmarkMetricsLogger, + CompletedRequestRecord, +) + + +def _make_record(request_id, now, offset, input_len=100, output_len=50): + return CompletedRequestRecord( + request_id=request_id, + completion_time=now + offset, + arrival_time=now + offset - 0.05, + inference_start_time=now + offset - 0.04, + first_token_time=now + offset - 0.02, + last_token_time=now + offset, + input_len=input_len, + output_len=output_len, + itl_samples=[0.02, 0.021, 0.019], + ) + + +def test_config_defaults(): + config = BenchmarkMetricsConfig(None) + assert config.enable is False + assert config.window_size == 0 + assert config.window_mode == "sliding" + assert config.percentile_values == [50.0, 90.0, 95.0, 99.0] + assert config.selected_metrics == set(BenchmarkMetricsConfig._ALL_METRICS) + + +def test_config_custom(): + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 200, "window_mode": "tumbling", "percentiles": "50,99", "metrics": "ttft,e2el"} + ) + assert config.enable is True + assert config.window_size == 200 + assert config.window_mode == "tumbling" + assert config.percentile_values == [50.0, 99.0] + assert config.selected_metrics == {"ttft", "e2el"} + + +def test_config_empty_dict(): + config = BenchmarkMetricsConfig({}) + assert config.enable is False + assert config.window_size == 0 + assert config.window_mode == "sliding" + assert config.percentile_values == [50.0, 90.0, 95.0, 99.0] + + +def test_config_enable_only(): + config = BenchmarkMetricsConfig({"enable": True}) + assert config.enable is True + assert config.window_mode == "sliding" + + +def test_logger_writes_jsonl(tmp_path): + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "percentiles": "50,99", "metrics": "ttft,e2el"}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i * 0.1)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + assert os.path.exists(jsonl_path) + + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + last_record = json.loads(lines[-1]) + assert last_record["completed"] == 5 + assert "ttft_ms" in last_record + assert "e2el_ms" in last_record + assert "tpot_ms" not in last_record + assert last_record["ttft_ms"]["mean"] > 0 + + +def test_logger_sliding_window(tmp_path): + """Sliding window: keeps the last N records, never clears.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 3, "window_mode": "sliding", "percentiles": "50", "metrics": "all"} + ) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + # After 5 records with window_size=3, the window always has at most 3 + rec3 = json.loads(lines[2]) # 3rd record: window full (3 records) + assert rec3["completed"] == 3 + + rec4 = json.loads(lines[3]) # 4th record: still 3 (oldest dropped) + assert rec4["completed"] == 3 + + last_record = json.loads(lines[-1]) + assert last_record["completed"] == 3 + assert last_record["window_size"] == 3 + assert last_record["window_mode"] == "sliding" + + +def test_logger_tumbling_window(tmp_path): + """Tumbling window: clears after reaching window_size, then restarts.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 3, "window_mode": "tumbling", "percentiles": "50", "metrics": "all"} + ) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(5): + logger.on_request_completed(_make_record(f"req-{i}", now, i)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + assert len(lines) == 5 + + # Records 1,2,3 accumulate then clear; records 4,5 start fresh + rec1 = json.loads(lines[0]) + assert rec1["completed"] == 1 + + rec3 = json.loads(lines[2]) # 3rd record: window full (3 records), then clears + assert rec3["completed"] == 3 + + rec4 = json.loads(lines[3]) # 4th record: window restarted, 1 record + assert rec4["completed"] == 1 + + rec5 = json.loads(lines[4]) # 5th record: 2 records in new window + assert rec5["completed"] == 2 + assert rec5["window_mode"] == "tumbling" + + +def test_logger_no_output_when_no_requests(tmp_path): + config = BenchmarkMetricsConfig({"enable": True}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + time.sleep(0.3) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + assert os.path.exists(jsonl_path) + with open(jsonl_path) as f: + content = f.read() + assert content == "" + + +def test_logger_enabled_flag(tmp_path): + """Logger with enable=False should have enabled=False.""" + config = BenchmarkMetricsConfig({"enable": False}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + assert logger.enabled is False + logger.shutdown() + + +def test_logger_enabled_true(tmp_path): + """Logger with enable=True should have enabled=True.""" + config = BenchmarkMetricsConfig({"enable": True}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + assert logger.enabled is True + logger.shutdown() + + +def test_stats_computation(): + stats = BenchmarkMetricsLogger._stats([10.0, 20.0, 30.0, 40.0, 50.0], [50.0, 99.0]) + assert stats["mean"] == 30.0 + assert stats["median"] == 30.0 + assert "p50" in stats + assert "p99" in stats + assert stats["p50"] == 30.0 + + +def test_stats_empty_list(): + stats = BenchmarkMetricsLogger._stats([], [50.0]) + assert stats == {} + + +def test_throughput_in_output(tmp_path): + """Throughput fields should appear when there are 2+ records.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "percentiles": "50", "metrics": "ttft"}) + log_dir = str(tmp_path) + logger = BenchmarkMetricsLogger(config=config, log_dir=log_dir, dp_rank=0) + + now = time.time() + for i in range(3): + logger.on_request_completed(_make_record(f"req-{i}", now, i * 0.5)) + + time.sleep(0.5) + logger.shutdown() + + jsonl_path = os.path.join(log_dir, "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + + # First record has no throughput (only 1 sample, duration=0) + rec1 = json.loads(lines[0]) + assert "request_throughput" not in rec1 + + # Last record should have throughput + last = json.loads(lines[-1]) + assert "request_throughput" in last + assert "output_throughput" in last + assert "total_throughput" in last + assert last["request_throughput"] > 0 + + +# ============================================================ +# Validation tests (via FDConfig.check()) +# ============================================================ + + +def _make_fd_config_with_benchmark(benchmark_cfg): + """Create a mock FDConfig with valid base attributes, only benchmark_metrics_config is real.""" + cfg = object.__new__(FDConfig) + # Mock all attributes accessed by check() before benchmark validation + cfg.scheduler_config = MagicMock() + cfg.scheduler_config.max_num_seqs = 128 + cfg.scheduler_config.max_num_batched_tokens = 8192 + cfg.scheduler_config.splitwise_role = "mixed" + cfg.scheduler_config.check = MagicMock() + cfg.model_config = MagicMock() + cfg.model_config.max_model_len = 8192 + cfg.cache_config = MagicMock() + cfg.cache_config.enable_chunked_prefill = True + cfg.cache_config.block_size = 64 + cfg.speculative_config = None + cfg.eplb_config = None + cfg.structured_outputs_config = None + cfg.graph_opt_config = MagicMock() + cfg.graph_opt_config.graph_opt_level = 0 + cfg.nnode = 1 + cfg.max_num_partial_prefills = 1 + cfg.max_long_partial_prefills = 1 + cfg.long_prefill_token_threshold = 0 + cfg.benchmark_metrics_config = benchmark_cfg + return cfg + + +@patch("fastdeploy.config.envs") +def test_valid_config_passes_check(mock_envs): + """Valid configs should pass FDConfig.check() without errors.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + configs = [ + {"enable": True}, + {"enable": True, "window_size": 64, "window_mode": "tumbling"}, + {"enable": False, "window_size": 0, "window_mode": "sliding"}, + {"enable": True, "percentiles": "50,90,99", "metrics": "ttft,e2el,s_decode"}, + ] + for args in configs: + benchmark_cfg = BenchmarkMetricsConfig(args) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + fd_cfg.check() # Should not raise + + +@patch("fastdeploy.config.envs") +def test_invalid_enable(mock_envs): + """enable must be a bool.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": "true"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'enable' must be a bool"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_size_negative(mock_envs): + """window_size must be non-negative.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_size": -1}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_size' must be a non-negative integer"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_size_type(mock_envs): + """window_size must be an integer.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_size": 3.5}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_size' must be a non-negative integer"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_window_mode(mock_envs): + """window_mode must be 'sliding' or 'tumbling'.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "window_mode": "fixed"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="'window_mode' must be 'sliding' or 'tumbling'"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_percentile_out_of_range(mock_envs): + """Percentile values must be in [0, 100].""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "percentiles": "50,101"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="percentile value .* out of range"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_percentile_negative(mock_envs): + """Percentile values must be >= 0.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "percentiles": "-1,50"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="percentile value .* out of range"): + fd_cfg.check() + + +@patch("fastdeploy.config.envs") +def test_invalid_metrics_unknown(mock_envs): + """Unknown metric names should fail validation.""" + mock_envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 + benchmark_cfg = BenchmarkMetricsConfig({"enable": True, "metrics": "ttft,unknown_metric"}) + fd_cfg = _make_fd_config_with_benchmark(benchmark_cfg) + with pytest.raises(AssertionError, match="unknown metric"): + fd_cfg.check() + + +# ============================================================ +# Direct method tests (bypass daemon thread for coverage) +# ============================================================ + + +def test_process_pending_direct(tmp_path): + """Directly call _process_pending to cover lines 98-109.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,99"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + # Add records directly to _pending without relying on background thread + for i in range(3): + logger._pending.append(_make_record(f"req-{i}", now, i * 0.5)) + + # Call _process_pending directly from main thread (coverage-tracked) + logger._process_pending() + + assert len(logger._window) == 3 + logger.shutdown() + + jsonl_path = os.path.join(str(tmp_path), "benchmark_metrics.jsonl") + with open(jsonl_path) as f: + lines = f.readlines() + assert len(lines) == 3 + rec = json.loads(lines[-1]) + assert rec["completed"] == 3 + assert "ttft_ms" in rec + assert "tpot_ms" in rec + assert "e2el_ms" in rec + assert "s_ttft_ms" in rec + assert "s_e2el_ms" in rec + assert "s_decode" in rec + assert "input_len" in rec + assert "s_input_len" in rec + assert "output_len" in rec + assert "request_throughput" in rec + assert "output_throughput" in rec + assert "total_throughput" in rec + + +def test_process_pending_tumbling_clear(tmp_path): + """Tumbling window clears after reaching window_size via direct call.""" + config = BenchmarkMetricsConfig( + {"enable": True, "window_size": 2, "window_mode": "tumbling", "metrics": "ttft", "percentiles": "50"} + ) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + for i in range(3): + logger._pending.append(_make_record(f"req-{i}", now, i * 0.5)) + + logger._process_pending() + + # After 3 records with window_size=2: first 2 fill window then clear, 3rd starts fresh + assert len(logger._window) == 1 + logger.shutdown() + + +def test_compute_rolling_stats_empty_window(tmp_path): + """_compute_rolling_stats with empty window returns minimal result.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + result = logger._compute_rolling_stats() + assert result["completed"] == 0 + logger.shutdown() + + +def test_compute_rolling_stats_single_record(tmp_path): + """Single record: no throughput, no tpot (needs output_len > 1 check).""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,99"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + # output_len=1 means tpot and decode_speed won't be computed + logger._window.append( + CompletedRequestRecord( + request_id="r1", + completion_time=now, + arrival_time=now - 0.05, + inference_start_time=now - 0.04, + first_token_time=now - 0.02, + last_token_time=now, + input_len=100, + output_len=1, + itl_samples=[], + ) + ) + + result = logger._compute_rolling_stats() + assert result["completed"] == 1 + assert "request_throughput" not in result # duration=0 for single record + assert result["ttft_ms"]["mean"] > 0 + assert result["tpot_ms"] == {} # no tpot with output_len=1 + assert result["s_itl_ms"] == {} # no itl samples + logger.shutdown() + + +def test_compute_rolling_stats_multiple_records(tmp_path): + """Multiple records: throughput and all metrics computed.""" + config = BenchmarkMetricsConfig({"enable": True, "window_size": 0, "metrics": "all", "percentiles": "50,95"}) + logger = BenchmarkMetricsLogger(config=config, log_dir=str(tmp_path), dp_rank=0) + + now = time.time() + for i in range(3): + logger._window.append(_make_record(f"req-{i}", now, i * 0.5)) + + result = logger._compute_rolling_stats() + assert result["completed"] == 3 + assert result["request_throughput"] > 0 + assert result["output_throughput"] > 0 + assert result["total_throughput"] > 0 + assert result["ttft_ms"]["mean"] > 0 + assert result["s_ttft_ms"]["mean"] > 0 + assert result["tpot_ms"]["mean"] > 0 + assert result["s_itl_ms"]["mean"] > 0 + assert result["e2el_ms"]["mean"] > 0 + assert result["s_e2el_ms"]["mean"] > 0 + assert result["s_decode"]["mean"] > 0 + assert "p50" in result["ttft_ms"] + assert "p95" in result["ttft_ms"] + logger.shutdown() + + +def test_stats_with_float_percentile(): + """Percentile key uses float format when not integer.""" + stats = BenchmarkMetricsLogger._stats([1.0, 2.0, 3.0], [99.9]) + assert "p99.9" in stats diff --git a/tests/metrics/test_new_metrics.py b/tests/metrics/test_new_metrics.py index 030acaf4299..f650d6d7d7c 100644 --- a/tests/metrics/test_new_metrics.py +++ b/tests/metrics/test_new_metrics.py @@ -54,6 +54,8 @@ def test_cache_metrics_update_history(self, mock_main_process_metrics): def setUp(self): """为 TokenProcessor 测试设置通用的 mock 对象。""" self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] self.mock_cached_generated_tokens = MagicMock() self.mock_engine_worker_queue = MagicMock() self.mock_split_connector = MagicMock() diff --git a/tests/model_executor/test_ep.py b/tests/model_executor/test_ep.py index 373e8899396..b099c7ad57e 100644 --- a/tests/model_executor/test_ep.py +++ b/tests/model_executor/test_ep.py @@ -419,6 +419,7 @@ def fake_get_moe_scores(*_args, **_kwargs): routed_scaling_factor=1.0, gate_correction_bias=None, renormalize=False, + fd_config=SimpleNamespace(scheduler_config=SimpleNamespace(enable_moe_scores_elementwise_fuse=False)), ) gate_out = paddle.randn([1, 4], dtype="float32") diff --git a/tests/model_executor/test_linear.py b/tests/model_executor/test_linear.py index 13f2bbe245e..aba98479303 100644 --- a/tests/model_executor/test_linear.py +++ b/tests/model_executor/test_linear.py @@ -58,6 +58,7 @@ def make_fd_config( expert_parallel_size=1, tp_group=None, use_sequence_parallel_moe=use_sequence_parallel_moe, + enable_flashinfer_allreduce_fusion=False, ), scheduler_config=SimpleNamespace(splitwise_role=splitwise_role, max_num_seqs=1), load_config=SimpleNamespace( diff --git a/tests/model_executor/test_model_executor_utils.py b/tests/model_executor/test_model_executor_utils.py index 98cba5c3302..701be987251 100644 --- a/tests/model_executor/test_model_executor_utils.py +++ b/tests/model_executor/test_model_executor_utils.py @@ -13,11 +13,16 @@ # limitations under the License. import unittest +import unittest.mock +from types import SimpleNamespace + +import paddle from fastdeploy.model_executor.utils import ( BitMaskTracker, TensorTracker, WeightsMapper, + process_weight_transpose, remap_weight_keys, set_weight_attrs, slice_fn, @@ -157,6 +162,132 @@ class Param: set_weight_attrs(p, None) # should not raise +class TestProcessWeightTranspose(unittest.TestCase): + def _make_layer(self, shape, dynamic_load_weight=True, load_strategy="rsync", rsync_config=None): + class Layer(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.fd_config = SimpleNamespace( + load_config=SimpleNamespace( + dynamic_load_weight=dynamic_load_weight, + load_strategy=load_strategy, + rsync_config=rsync_config or {}, + ), + model_config=SimpleNamespace(enable_cache=False), + ) + self.weight = self.create_parameter( + shape=shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + is_bias=False, + ) + + return Layer() + + def test_gdr_dynamic_transpose_preserves_loading_attrs_for_future_updates(self): + def loader(): + return None + + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.weight_need_transpose = False + layer.weight.weight_loader = loader + layer.weight.is_distributed = True + layer.weight.split_axis = 1 + layer.weight.tensor_track = object() + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertIs(layer.weight.weight_loader, loader) + self.assertTrue(layer.weight.is_distributed) + self.assertEqual(layer.weight.split_axis, 0) + self.assertFalse(hasattr(layer.weight, "tensor_track")) + + def test_gpu_direct_dynamic_transpose_preserves_loading_attrs(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 0) + + def test_rdma_dynamic_transpose_does_not_preserve_loading_attrs(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(hasattr(layer.weight, "output_dim")) + self.assertFalse(hasattr(layer.weight, "weight_need_transpose")) + self.assertFalse(hasattr(layer.weight, "split_axis")) + + def test_ct_ipc_dynamic_transpose_preserves_loading_attrs(self): + layer = self._make_layer([8, 4], load_strategy="ipc") + layer.weight.output_dim = True + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertFalse(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 0) + + def test_gdr_transpose_preserves_loading_attrs_for_3d_weight(self): + layer = self._make_layer([2, 8, 4]) + layer.weight.output_dim = False + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [2, 4, 8]) + self.assertTrue(layer.weight.output_dim) + self.assertTrue(layer.weight.weight_need_transpose) + self.assertEqual(layer.weight.split_axis, 2) + + def test_gdr_transpose_clears_weight_need_transpose_for_torch_format(self): + """Production scenario: torch format sets weight_need_transpose=True. + After transpose, param is in HF layout so no transpose needed on reload.""" + + def loader(): + return None + + layer = self._make_layer([8, 4]) + layer.weight.output_dim = True + layer.weight.weight_need_transpose = True + layer.weight.weight_loader = loader + layer.weight.split_axis = 1 + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertEqual(layer.weight.shape, [4, 8]) + self.assertFalse(layer.weight.output_dim) + self.assertFalse(layer.weight.weight_need_transpose) + self.assertIs(layer.weight.weight_loader, loader) + self.assertEqual(layer.weight.split_axis, 0) + + def test_gdr_transpose_preserves_none_output_dim(self): + layer = self._make_layer([8, 4]) + layer.weight.output_dim = None + + with unittest.mock.patch.dict("os.environ", {"FD_USE_GDR_CHECKPOINT_TRANSFER": "1"}): + process_weight_transpose(layer, "weight") + + self.assertIsNone(layer.weight.output_dim) + + class TestSliceFn(unittest.TestCase): def test_1d_slice(self): import numpy as np diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py index 139b6859951..4cc5a1563bd 100644 --- a/tests/model_executor/test_thinking_budget.py +++ b/tests/model_executor/test_thinking_budget.py @@ -111,7 +111,7 @@ def setUp(self): self._fdconfig_patches = [ patch.object(FDConfig, "read_from_config", return_value=None), patch.object(FDConfig, "postprocess", return_value=None), - patch.object(FDConfig, "init_cache_info", return_value=None), + patch.object(FDConfig, "init_pd_info", return_value=None), patch.object(FDConfig, "check", return_value=None), ] for patcher in self._fdconfig_patches: diff --git a/tests/model_loader/test_load_ernie_vl.py b/tests/model_loader/test_load_ernie_vl.py index abbdeb542f5..129c6076533 100644 --- a/tests/model_loader/test_load_ernie_vl.py +++ b/tests/model_loader/test_load_ernie_vl.py @@ -15,7 +15,6 @@ import json import os import signal -import socket import subprocess import sys import time @@ -28,96 +27,14 @@ if project_root not in sys.path: sys.path.insert(0, project_root) -# Read ports from environment variables; use default values if not set -FD_API_PORT = int(os.getenv("FD_API_PORT", 8188)) -FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8133)) -FD_METRICS_PORT = int(os.getenv("FD_METRICS_PORT", 8233)) -FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333)) - -# List of ports to clean before and after tests -PORTS_TO_CLEAN = [FD_API_PORT, FD_ENGINE_QUEUE_PORT, FD_METRICS_PORT, FD_CACHE_QUEUE_PORT] - - -def is_port_open(host: str, port: int, timeout=1.0): - """ - Check if a TCP port is open on the given host. - Returns True if connection succeeds, False otherwise. - """ - try: - with socket.create_connection((host, port), timeout): - return True - except Exception: - return False - - -def kill_process_on_port(port: int): - """ - Kill processes that are listening on the given port. - Uses multiple methods to ensure thorough cleanup. - """ - current_pid = os.getpid() - parent_pid = os.getppid() - - # Method 1: Use lsof to find processes - try: - output = subprocess.check_output(f"lsof -i:{port} -t", shell=True).decode().strip() - for pid in output.splitlines(): - pid = int(pid) - if pid in (current_pid, parent_pid): - print(f"Skip killing current process (pid={pid}) on port {port}") - continue - try: - # First try SIGTERM for graceful shutdown - os.kill(pid, signal.SIGTERM) - time.sleep(1) - # Then SIGKILL if still running - os.kill(pid, signal.SIGKILL) - print(f"Killed process on port {port}, pid={pid}") - except ProcessLookupError: - pass # Process already terminated - except subprocess.CalledProcessError: - pass - - # Method 2: Use netstat and fuser as backup - try: - # Find processes using netstat and awk - cmd = f"netstat -tulpn 2>/dev/null | grep :{port} | awk '{{print $7}}' | cut -d'/' -f1" - output = subprocess.check_output(cmd, shell=True).decode().strip() - for pid in output.splitlines(): - if pid and pid.isdigit(): - pid = int(pid) - if pid in (current_pid, parent_pid): - continue - try: - os.kill(pid, signal.SIGKILL) - print(f"Killed process (netstat) on port {port}, pid={pid}") - except ProcessLookupError: - pass - except (subprocess.CalledProcessError, FileNotFoundError): - pass - - # Method 3: Use fuser if available - try: - subprocess.run(f"fuser -k {port}/tcp", shell=True, timeout=5) - except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): - pass - - -def clean_ports(): - """ - Kill all processes occupying the ports listed in PORTS_TO_CLEAN. - """ - print(f"Cleaning ports: {PORTS_TO_CLEAN}") - for port in PORTS_TO_CLEAN: - kill_process_on_port(port) - - # Double check and retry if ports are still in use - time.sleep(2) - for port in PORTS_TO_CLEAN: - if is_port_open("127.0.0.1", port, timeout=0.1): - print(f"Port {port} still in use, retrying cleanup...") - kill_process_on_port(port) - time.sleep(1) +from e2e.utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean_ports, + is_port_open, +) @pytest.fixture(scope="session", autouse=True) @@ -184,8 +101,8 @@ def setup_and_run_server(): start_new_session=True, # Enables killing full group via os.killpg ) - # Wait up to 10 minutes for API server to be ready - for _ in range(10 * 60): + # Wait up to 5 minutes for API server to be ready + for _ in range(300): if is_port_open("127.0.0.1", FD_API_PORT): print(f"API server is up on port {FD_API_PORT}") break diff --git a/tests/operators/attention/test_decode_unified_attention_c16.py b/tests/operators/attention/test_decode_unified_attention_c16.py new file mode 100644 index 00000000000..0d17d17ccd6 --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c16.py @@ -0,0 +1,868 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention as append_attention_op, +) +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + emb = paddle.unsqueeze(emb, 2) + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, start_pos=0): + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, start_pos : start_pos + seq, ...] + sin = sin[:, :, start_pos : start_pos + seq, ...] + cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos)) + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def naive_attention_impl(query, key, value, cache_k=None, cache_v=None, mask=None, scale=1.0): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if cache_k is not None: + cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim]) + cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1]) + cache_k = cache_k.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([cache_k, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if cache_v is not None: + cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim]) + cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1]) + cache_v = cache_v.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([cache_v, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = qk_res * scale + if mask is not None: + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len): + """Read K/V from paged cache and return as [batch, num_head, seq_len, dim_head].""" + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeUnifiedAttentionC16(unittest.TestCase): + """Base test class for decode append attention with cache_quant_type='none' (fp16/bf16 KV cache). + + Uses append_attention for prefill (verified correct by test_append_attention_c16.py) + and then tests decode_unified_attention (new split ops) against the same naive reference. + + Subclasses override setUp to vary batch_size, max_tokens_per_batch, dtype, etc. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + + # Use small seq_len for fast testing; can increase later + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def init_tensor(self): + self.rope = RopeEmbedding(self.use_neox_rotary_style) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + + # Encoder phase: prefill with seq_len tokens + self.enc_q, self.enc_k, self.enc_v, self.enc_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.seq_len, + self.head_dim, + self.place, + self.dtype, + ) + + # Decoder phase: max_tokens_per_batch decode tokens + self.dec_q, self.dec_k, self.dec_v, self.dec_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + + def _get_block_shape_buffers(self, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time): + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + return { + "decoder_batch_ids": decoder_batch_ids, + "decoder_tile_ids_per_batch": decoder_tile_ids_per_batch, + "decoder_num_blocks_cpu": decoder_num_blocks_cpu, + "encoder_batch_ids": encoder_batch_ids, + "encoder_tile_ids_per_batch": encoder_tile_ids_per_batch, + "encoder_num_blocks_cpu": encoder_num_blocks_cpu, + "kv_batch_ids": kv_batch_ids, + "kv_tile_ids_per_batch": kv_tile_ids_per_batch, + "kv_num_blocks_x_cpu": kv_num_blocks_x_cpu, + "max_len_tensor_cpu": max_len_tensor_cpu, + } + + def run_append_attention( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention op.""" + buffers = self._get_block_shape_buffers(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time) + + qkv_copy = copy.deepcopy(qkv) + cache_k_copy = copy.deepcopy(cache_k) + cache_v_copy = copy.deepcopy(cache_v) + + out = append_attention_op( + qkv_copy, + cache_k_copy, + cache_v_copy, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + -1, + 64, + 16, + 1024, + self.max_model_len, + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, + self.max_tokens_per_batch > 1, # speculate_decoder + ) + return out, cache_k_copy, cache_v_copy + + def _build_decode_buffer(self): + """Build buffer for new split decode ops.""" + buffer = {} + min_chunk_size = 512 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + return buffer + + def _run_decode_unified_attention( + self, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run config_for_attention + decoder_write_cache_with_rope + decode_unified_attention.""" + buffer = self._build_decode_buffer() + + config_for_attention( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + + dec_cache_k = copy.deepcopy(cache_k) + dec_cache_v = copy.deepcopy(cache_v) + dec_qkv = copy.deepcopy(self.dec_qkv) + + decoder_write_cache_with_rope( + dec_qkv, + dec_cache_k, + dec_cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["max_len_tensor_cpu"], + self.rotary_embs, + None, # qkv_bias + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_unified_attention( + dec_qkv, + dec_cache_k, + dec_cache_v, + buffer["tmp_workspace"], + buffer["tmp_m"], + buffer["tmp_d"], + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + None, # attn_mask + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks + paddle.empty([dec_qkv.shape[0], self.q_num_head * self.head_dim], dtype=dec_qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + ) + return out, dec_cache_k, dec_cache_v + + def do_prefill_with_append_attention(self): + """Prefill using append_attention. Returns cache_k, cache_v after prefill.""" + seq_lens_encoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + seq_lens_decoder = paddle.to_tensor([0] * self.batch_size, "int32") + seq_lens_this_time = copy.deepcopy(seq_lens_encoder) + + batch_id_per_token, cu_seqlens_q, _ = get_padding_offset(self.batch_size, seq_lens_this_time) + + _, cache_k, cache_v = self.run_append_attention( + self.enc_qkv, + self.cache_k, + self.cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ) + return cache_k, cache_v + + def compute_naive_decode_ref(self, cache_k, cache_v): + """Compute naive reference for decode step using cache from paged cache.""" + # Read K/V from paged cache + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + cache_k, cache_v, self.batch_size, self.block_tables, self.seq_len + ) + + # Only use the first decode token (seq_lens_this_time=1 per batch) + dec_q = self.dec_q[:, :, :1, :] + dec_k = self.dec_k[:, :, :1, :] + dec_v = self.dec_v[:, :, :1, :] + + # Apply RoPE to decode Q/K at position seq_len + dec_q_rope, dec_k_rope = self.rope._apply_rope(self.rotary_embs, dec_q, dec_k, start_pos=self.seq_len) + + # Compute naive attention + out_ref = naive_attention_impl( + dec_q_rope, + dec_k_rope, + dec_v, + cache_k=naive_cache_k, + cache_v=naive_cache_v, + scale=self.softmax_scale, + ) + + dec_seq_lens_this_time = paddle.to_tensor([1] * self.batch_size, "int32") + dec_token_num = self.batch_size + _, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + out_ref = remove_padding(dec_seq_lens_this_time, dec_cu_seqlens_q, out_ref, dec_token_num) + return out_ref + + def test_naive_vs_append_attention_decode(self): + """Test: prefill with append_attention, then decode with append_attention. Compare to naive.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_dec, _, _ = self.run_append_attention( + self.dec_qkv, + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_dec_f = out_dec.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_dec_f = out_dec_f[:dec_token_num] + + np.testing.assert_allclose( + out_dec_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="append_attention decode output doesn't match naive reference", + ) + + def test_naive_vs_decode_unified_attention(self): + """Test: prefill with append_attention, then decode with new split decode ops.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with new split ops + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_decode_f = out.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention output doesn't match naive reference", + ) + + def test_append_vs_decode_unified_attention(self): + """Test: append_attention decode vs new split decode ops should produce same result.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append, _, _ = self.run_append_attention( + self.dec_qkv, + copy.deepcopy(cache_k), + copy.deepcopy(cache_v), + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with new split ops + out_decode, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_append_f = out_append.astype("float32").numpy() + out_decode_f = out_decode.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_append_f = out_append_f[:dec_token_num] + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention doesn't match append_attention decode", + ) + + +class TestDecodeUnifiedAttentionC16Speculate(TestDecodeUnifiedAttentionC16): + """Test with speculate decode: max_tokens_per_batch=2. + + When max_tokens_per_batch > 1, naive ref only computes 1 token while ops + compute multiple tokens. So naive comparison tests are skipped; only + append_attention vs decode_unified_attention comparison is kept. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_unified_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +class TestDecodeUnifiedAttentionC16MultiBatch(TestDecodeUnifiedAttentionC16): + """Test with multiple batches.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiHead(TestDecodeUnifiedAttentionC16): + """Test with multiple KV heads (GQA).""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 2 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16FP16(TestDecodeUnifiedAttentionC16): + """Test with float16 dtype.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "float16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16NoCausal(TestDecodeUnifiedAttentionC16): + """Test with causal=False.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = False + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiBatchSpeculate(TestDecodeUnifiedAttentionC16): + """Test with multi-batch + speculate decode. + + When max_tokens_per_batch > 1, the naive reference only computes 1 token + while ops compute multiple tokens. So we only compare append_attention vs + decode_unified_attention (both should produce same result), and skip the + naive comparison tests. + """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_append_attention_decode(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + def test_naive_vs_decode_unified_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/attention/test_decode_unified_attention_c8.py b/tests/operators/attention/test_decode_unified_attention_c8.py new file mode 100644 index 00000000000..d5ec0e5354c --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c8.py @@ -0,0 +1,921 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention, + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + pre_cache_len_concat, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, D/2] + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + # shape: [B, S, 1, D/2] + emb = paddle.unsqueeze(emb, 2) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeUnifiedAttention(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + def init_tensor(self): + # seq_lens + if self.seq_len_dec is None: + self.seq_lens_dec = [ + self.cache_len, + ] * self.batch_size + else: + self.batch_size = len(self.seq_lens_dec) + self.seq_lens_decoder = paddle.to_tensor( + self.seq_lens_dec, + "int32", + ) + if self.seq_lens_this_time is None: + self.seq_lens_this_time = [ + self.max_tokens_per_batch, + ] * self.batch_size + self.token_num = sum(self.seq_lens_this_time) + self.seq_lens_this_time = paddle.to_tensor(self.seq_lens_this_time, "int32") + + self.seq_lens_enc = [0] * self.batch_size + + self.seq_lens_encoder = paddle.to_tensor( + self.seq_lens_enc, + "int32", + ) + + # self.qkv = paddle.rand([self.token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim], dtype=self.dtype) + self.q, self.k, self.v, self.qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + self.qkv = paddle.to_tensor(self.qkv, dtype=self.dtype) + + # qk_norm + self.q_norm_weight = None + self.k_norm_weight = None + if self.use_qk_norm: + q_norm_weight_np = np.random.random([self.head_dim]) / 10 + k_norm_weight_np = np.random.random([self.head_dim]) / 10 + self.q_norm_weight = paddle.to_tensor(q_norm_weight_np, dtype="float32") + self.k_norm_weight = paddle.to_tensor(k_norm_weight_np, dtype="float32") + + # rotary embedding + self.rope = RopeEmbedding(False) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache_kv && scale + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + + if self.cache_quant_type == "block_wise_fp8": + self.cache_scale_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_k_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_v_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype) + self.cache_k_out_scale = None + self.cache_v_out_scale = None + else: + self.cache_k_scale = ( + self.quant_max_bound / self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) + ).astype(self.dtype) + self.cache_v_scale = ( + self.quant_max_bound / self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) + ).astype(self.dtype) + + self.cache_k_out_scale = ( + self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) / self.quant_max_bound + ).astype(self.dtype) + self.cache_v_out_scale = ( + self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1) / self.quant_max_bound + ).astype(self.dtype) + + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8") + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8") + + ( + self.batch_id_per_token, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, self.seq_lens_this_time) + + # mask offset + self.mask_offset = None + if self.use_mask_offset: + self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32") + for i in range(self.batch_size): + self.mask_offset[i * 2] = 0 + self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1 + + # buffer + self.buffer = {} + min_chunk_size = 512 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + self.group_size = self.q_num_head // self.kv_num_head + q_tile_size = 16 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + # block_indices: Launched block's indices with 4 dimensions [batch_idx, kv_head_idx, chunk_idx, q_tile_idx] in decode append attention backend + self.buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + # num_blocks: Number of Launched blocks in decode append attention backend, researched by config_for_attention op + self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + # chunk_size: Chunk size for split kv cache in decode append attention backend, researched by config_for_attention op + self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + # tmp_workspace: Workspace tensor for temporary store the result before merging in decode append attention backend + self.buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + # tmp_m: Tmp_m tensor for temporary store the max value before merging in decode append attention backend + self.buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + # tmp_d: Tmp_d tensor for temporary store the exponential sum before merging in decode append attention backend + self.buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + + def append_attention_with_args( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention with explicit arguments.""" + # buffer + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + out = append_attention( + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + max_len_tensor_cpu, + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + self.q_norm_weight, + self.k_norm_weight, + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + -1, + 64, + 16, + self.max_model_len, + 1024, + self.max_tokens_per_batch, + self.causal, + self.max_tokens_per_batch > 1, + self.sliding_window, + ) + return out, cache_k, cache_v + + def append_attention(self): + """Convenience wrapper using default self members.""" + return self.append_attention_with_args( + copy.deepcopy(self.qkv), + copy.deepcopy(self.cache_k), + copy.deepcopy(self.cache_v), + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + ) + + def decode_unified_attention(self): + paddle.disable_static() + + config_for_attention( + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + # print(f"num_blocks: {self.buffer['num_blocks']}") + decoder_write_cache_with_rope( + self.qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["max_len_tensor_cpu"], + self.rotary_embs, # rotary_embs + None, # qkv_bias + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + self.q_norm_weight, # q_norm_weight + self.k_norm_weight, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + False, # use_neox_rotary_style + self.rope_3d, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_unified_attention( + self.qkv, + self.cache_k, + self.cache_v, + self.buffer["tmp_workspace"], + self.buffer["tmp_m"], + self.buffer["tmp_d"], + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.batch_id_per_token, + self.cu_seqlens_q, + self.block_tables, + self.buffer["block_indices"], + self.buffer["num_blocks"], + self.buffer["chunk_size"], + self.buffer["max_len_tensor_cpu"], # set_max_lengths + None, # attn_mask + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks # sinks + paddle.empty([self.qkv.shape[0], self.q_num_head * self.head_dim], dtype=self.qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + self.quant_max_bound, # quant_max_bound + self.quant_min_bound, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + self.sliding_window, + ) + return self.qkv, out + + def prefill(self): + # init seq_len + seq_lens_encoder = copy.deepcopy(self.seq_lens_decoder) + seq_lens_decoder = paddle.zeros([self.batch_size], dtype="int32") + seq_lens_this_time = seq_lens_encoder + token_num = seq_lens_this_time.sum().item() + qkv_np = np.random.random([token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim]) - 0.5 + qkv = paddle.to_tensor(qkv_np, dtype=self.dtype) + + ( + batch_id_per_token, + cu_seqlens_q, + cu_seqlens_k, + ) = get_padding_offset(self.batch_size, seq_lens_this_time) + # buffer + decode_max_tile_size = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + ( + cu_seqlens_k, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + kv_token_num_cpu, + ) = pre_cache_len_concat( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + max_len_tensor_cpu[2], + self.block_size, + ) + q, k, v, _ = gqa_rope_write_cache( + qkv, + self.cache_k, + self.cache_v, + cu_seqlens_q, + cu_seqlens_k, + self.rotary_embs, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + batch_id_per_token, + self.block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + pre_cache_batch_ids, + pre_cache_tile_ids_per_batch, + pre_cache_num_blocks_cpu, + self.q_norm_weight, + self.k_norm_weight, + self.cache_k_scale, # cache_k_quant_scales + self.cache_v_scale, # cache_v_quant_scales + self.cache_k_out_scale, # cache_k_dequant_scales + self.cache_v_out_scale, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + kv_token_num_cpu[0].item(), + self.max_model_len, + self.rms_norm_eps, + False, # use_neox_rotary_style + self.cache_quant_type, + self.rope_3d, + ) + + k = k.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + v = v.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3]) + return k, v + + def test_all(self): + """Compare append_attention vs decode_unified_attention output for consistency.""" + # Step 1: Prefill - just write K/V to cache via gqa_rope_write_cache + self.prefill() + + # Step 2: Decode with append_attention (copy cache so it's not modified) + dec_seq_lens_encoder = paddle.zeros([self.batch_size], dtype="int32") + dec_seq_lens_decoder = copy.deepcopy(self.seq_lens_decoder) + + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, dtype="int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append_dec, _, _ = self.append_attention_with_args( + copy.deepcopy(self.qkv), + copy.deepcopy(self.cache_k), + copy.deepcopy(self.cache_v), + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with decode_unified_attention (uses self.cache_k/v directly) + _, out_decode = self.decode_unified_attention() + + # Step 4: Compare + out_append_f = out_append_dec.astype("float32").numpy() + out_decode_f = out_decode.astype("float32").numpy() + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention output doesn't match append_attention output", + ) + + +class TestDecodeUnifiedAttentionMultiBatch(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 60 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionSpeculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionMultiHead(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionMultiSpeculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionSpeculateBs128Mtp4(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 128 + self.max_tokens_per_batch = 4 + self.cache_len = 508 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 2048 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8MultiBatch(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 60 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionDynamicC8Speculate(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 4 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "block_wise_fp8" + self.use_qk_norm = False + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +class TestDecodeUnifiedAttentionQKNorm(TestDecodeUnifiedAttention): + def setUp(self): + paddle.disable_static() + self.name = "TestDecodeUnifiedAttention" + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 6 + self.max_tokens_per_batch = 2 + self.cache_len = 500 + self.seq_len_dec = None + self.seq_lens_this_time = None + self.max_model_len = 131072 + self.head_dim = 128 + self.rms_norm_eps = 1e-6 + self.rope_3d = False + self.q_hid_dim = self.q_num_head * self.head_dim + self.kv_hid_dim = self.kv_num_head * self.head_dim + self.block_size = 64 + self.use_neox_rotary_style = False + self.softmax_scale = self.head_dim**-0.5 + self.rope_theta = 10000 + self.sliding_window = 0 + self.dtype = "bfloat16" + self.cache_quant_type = "cache_fp8" + self.use_qk_norm = True + self.use_mask_offset = False + self.causal = True + self.quant_min_bound = -448.0 + self.quant_max_bound = 448.0 + self.init_tensor() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_build_sampling_params_logprob.py b/tests/operators/test_build_sampling_params_logprob.py new file mode 100644 index 00000000000..9eb5e8e3052 --- /dev/null +++ b/tests/operators/test_build_sampling_params_logprob.py @@ -0,0 +1,269 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +# --- Import ops (bypass fastdeploy.__init__) --- +try: + import os + import sys + + _fd_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + if _fd_root not in sys.path: + sys.path.insert(0, _fd_root) + from fastdeploy.import_ops import import_custom_ops + + _package = "fastdeploy.model_executor.ops.gpu" + import_custom_ops(_package, ".fastdeploy_ops", globals()) +except ImportError as e: + print(f"Import error: {e}") + raise + +CUDA_PLACE = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + + +# ============================================================ +# Layer 1: Helpers -- tensor creation / kernel invocation / output extraction +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy dict -> paddle tensors on GPU. Scalar attrs are passed through.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs, inputs): + """Call build_sampling_params_logprob with paddle tensors + scalar attrs.""" + return build_sampling_params_logprob( # noqa: F821 + paddle_inputs["input_params"], + paddle_inputs["token_num_per_batch"], + inputs["token_num_output_cpu"], + ) + + +def get_outputs(result) -> Dict[str, np.ndarray]: + """Extract output tensor to numpy.""" + return {"output_params": result.numpy()} + + +# ============================================================ +# Layer 2: Input generation +# ============================================================ + + +def gen_inputs( + real_bsz=8, + max_tokens_per_batch=5, + dtype=np.float32, + seed=42, +) -> Dict[str, Any]: + """Generate randomized test inputs. + + Args: + real_bsz: number of batch items + max_tokens_per_batch: max token count per batch item + dtype: numpy dtype for input_params (np.float32, np.int32, np.bool_) + seed: random seed + """ + rng = np.random.default_rng(seed) + + # Random token counts per batch, allow zeros (empty slots) + token_num_per_batch = rng.integers(0, max_tokens_per_batch + 1, size=real_bsz).astype(np.int32) + token_num_output_cpu = int(token_num_per_batch.sum()) + + # Generate per-batch param values + if dtype == np.float32: + input_params = rng.uniform(0.0, 1.0, size=real_bsz).astype(np.float32) + elif dtype == np.int32: + input_params = rng.integers(0, 100, size=real_bsz).astype(np.int32) + elif dtype == np.bool_: + input_params = rng.choice([False, True], size=real_bsz) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + return { + "input_params": input_params, + "token_num_per_batch": token_num_per_batch, + "token_num_output_cpu": token_num_output_cpu, + } + + +# ============================================================ +# Layer 3: Reference implementation (pure Python/NumPy) +# ============================================================ + + +def reference_build_sampling_params_logprob(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Python reference -- must match CUDA kernel logic exactly. + + Kernel logic: + 1. Initialize output with safe defaults (bool->False, int32->1, float32->1.0) + 2. For each batch bi, fill output[start_offset..start_offset+cur_token_num-1] + with input_params[bi], where start_offset = sum(token_num_per_batch[0..bi-1]) + """ + input_params = inputs["input_params"].copy() + token_num_per_batch = inputs["token_num_per_batch"].copy() + token_num_output_cpu = inputs["token_num_output_cpu"] + real_bsz = len(input_params) + dtype = input_params.dtype + + # Initialize output with safe defaults (matching kernel behavior) + if dtype == np.bool_: + output_params = np.full(token_num_output_cpu, False, dtype=dtype) + elif dtype == np.int32: + output_params = np.full(token_num_output_cpu, 1, dtype=dtype) + elif dtype == np.float32: + output_params = np.full(token_num_output_cpu, 1.0, dtype=dtype) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + for bi in range(real_bsz): + start_offset = int(token_num_per_batch[:bi].sum()) + cur_token_num = int(token_num_per_batch[bi]) + if cur_token_num <= 0: + continue + val = input_params[bi] + for i in range(cur_token_num): + idx = start_offset + i + if idx < token_num_output_cpu: + output_params[idx] = val + + return {"output_params": output_params} + + +# ============================================================ +# Layer 4a: TEST_CONFIGS -- all pure-parameter test scenarios +# ============================================================ + +TEST_CONFIGS = [ + # --- basic coverage, float32 --- + {"name": "float32_small_batch", "real_bsz": 2, "max_tokens_per_batch": 3, "dtype": np.float32, "seed": 42}, + {"name": "float32_medium_batch", "real_bsz": 16, "max_tokens_per_batch": 8, "dtype": np.float32, "seed": 42}, + {"name": "float32_large_batch", "real_bsz": 64, "max_tokens_per_batch": 16, "dtype": np.float32, "seed": 42}, + # --- int32 dtype --- + {"name": "int32_small_batch", "real_bsz": 4, "max_tokens_per_batch": 5, "dtype": np.int32, "seed": 42}, + {"name": "int32_large_batch", "real_bsz": 32, "max_tokens_per_batch": 10, "dtype": np.int32, "seed": 42}, + # --- bool dtype --- + {"name": "bool_small_batch", "real_bsz": 4, "max_tokens_per_batch": 5, "dtype": np.bool_, "seed": 42}, + {"name": "bool_large_batch", "real_bsz": 32, "max_tokens_per_batch": 10, "dtype": np.bool_, "seed": 42}, + # --- edge cases --- + {"name": "single_batch_single_token", "real_bsz": 1, "max_tokens_per_batch": 1, "dtype": np.float32, "seed": 42}, + {"name": "single_batch_many_tokens", "real_bsz": 1, "max_tokens_per_batch": 64, "dtype": np.float32, "seed": 42}, + {"name": "many_batch_one_token", "real_bsz": 64, "max_tokens_per_batch": 1, "dtype": np.float32, "seed": 42}, +] + + +# ============================================================ +# Layer 4b: Test suite +# ============================================================ + + +class TestBuildSamplingParamLogprob(unittest.TestCase): + + # ------ shared helpers ------ + + def _run_and_get(self, inputs): + paddle_inputs = to_paddle_inputs(inputs) + result = run_kernel(paddle_inputs, inputs) + return get_outputs(result) + + def _check_all_outputs(self, inputs, outputs): + """Compare ALL output tensors against reference.""" + ref = reference_build_sampling_params_logprob(inputs) + np.testing.assert_array_equal(outputs["output_params"], ref["output_params"], err_msg="output_params mismatch") + + def _run_full_test(self, config): + inputs = gen_inputs(**config) + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + return outputs + + # ------ test cases ------ + + def test_configs(self): + """Run all TEST_CONFIGS via subTest (one subTest per config).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + self._run_full_test(test_cfg) + + def test_all_zero_token_counts(self): + """All batch items have zero tokens -- output should be empty array.""" + inputs = gen_inputs(real_bsz=4, max_tokens_per_batch=1, dtype=np.float32, seed=42) + # Force all token counts to zero + inputs["token_num_per_batch"] = np.zeros(4, dtype=np.int32) + inputs["token_num_output_cpu"] = 0 + outputs = self._run_and_get(inputs) + self.assertEqual(outputs["output_params"].size, 0) + + def test_exact_golden_float32(self): + """Exact golden values for float32 -- hand-verified.""" + inputs = { + "input_params": np.array([0.5, 0.9, 0.1], dtype=np.float32), + "token_num_per_batch": np.array([2, 3, 1], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([0.5, 0.5, 0.9, 0.9, 0.9, 0.1], dtype=np.float32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_exact_golden_int32(self): + """Exact golden values for int32 -- hand-verified.""" + inputs = { + "input_params": np.array([10, 20, 30], dtype=np.int32), + "token_num_per_batch": np.array([1, 2, 3], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([10, 20, 20, 30, 30, 30], dtype=np.int32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_exact_golden_bool(self): + """Exact golden values for bool -- hand-verified.""" + inputs = { + "input_params": np.array([True, False, True], dtype=np.bool_), + "token_num_per_batch": np.array([3, 2, 1], dtype=np.int32), + "token_num_output_cpu": 6, + } + outputs = self._run_and_get(inputs) + expected = np.array([True, True, True, False, False, True], dtype=np.bool_) + np.testing.assert_array_equal(outputs["output_params"], expected) + + def test_mixed_with_empty_slots(self): + """Some batch items have zero tokens (empty slots).""" + inputs = { + "input_params": np.array([0.5, 0.9, 0.1, 0.7], dtype=np.float32), + "token_num_per_batch": np.array([2, 0, 3, 0], dtype=np.int32), + "token_num_output_cpu": 5, + } + outputs = self._run_and_get(inputs) + # bi=0: tokens 0,1 -> 0.5; bi=1: empty; bi=2: tokens 2,3,4 -> 0.1; bi=3: empty + expected = np.array([0.5, 0.5, 0.1, 0.1, 0.1], dtype=np.float32) + np.testing.assert_array_equal(outputs["output_params"], expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_fused_rotary_position_encoding.py b/tests/operators/test_fused_rotary_position_encoding.py index cbff608c7c4..8ab5fb2c2ca 100644 --- a/tests/operators/test_fused_rotary_position_encoding.py +++ b/tests/operators/test_fused_rotary_position_encoding.py @@ -116,9 +116,10 @@ def test_neox_mode(self): self._check_correctness(num_tokens=3, num_heads=2, num_kv_heads=2, head_size=8, rot_dim=8, is_neox=True) def test_large_num_tokens(self): - self._check_correctness(num_tokens=10, num_heads=2, num_kv_heads=2, head_size=4, rot_dim=4, is_neox=False) - - def test_exceed_max_tokens(self): + """ + 测试算子支持大量 tokens(超过 65535) + 算子使用 2D grid,理论上可支持 65535*65535 个 tokens + """ num_tokens, num_heads, head_size = 65537, 1, 4 num_kv_heads, rot_dim = 1, 4 query_np = np.random.rand(num_tokens, num_heads, head_size).astype("float32") @@ -126,8 +127,10 @@ def test_exceed_max_tokens(self): position_ids_np = np.arange(num_tokens, dtype="int32") cos_sin_cache_np = self._make_cos_sin_cache(num_tokens, rot_dim) - with self.assertRaises(Exception): - self._run_op(query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False) + # 不应该抛出异常,算子应该能处理大量 tokens + query_out, key_out = self._run_op( + query_np, key_np, position_ids_np, cos_sin_cache_np, head_size, is_neox=False + ) if __name__ == "__main__": diff --git a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py index 41474b4726c..54b34850780 100644 --- a/tests/operators/test_get_position_ids_and_mask_encoder_batch.py +++ b/tests/operators/test_get_position_ids_and_mask_encoder_batch.py @@ -17,59 +17,47 @@ import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import get_position_ids_and_mask_encoder_batch +from fastdeploy.model_executor.ops.gpu import get_position_ids -class TestGetPositionIdsAndMaskEncoderBatch(unittest.TestCase): +class TestGetPositionIds(unittest.TestCase): def setUp(self): np.random.seed(42) paddle.set_device("gpu") def test_basic_functionality(self): # Test normal case with batch size 2 - seq_lens_encoder = paddle.to_tensor([3, 2], dtype="int32") - seq_lens_decoder = paddle.to_tensor([1, 2], dtype="int32") + seq_lens_encoder = paddle.to_tensor([1, 2], dtype="int32") + seq_lens_decoder = paddle.to_tensor([3, 2], dtype="int32") seq_lens_this_time = paddle.to_tensor([1, 2], dtype="int32") - total_len = int(seq_lens_encoder.numpy().sum() + seq_lens_this_time.numpy().sum()) + total_len = int(seq_lens_this_time.numpy().sum()) position_ids = paddle.zeros([total_len], dtype="int32") - mask_encoder_batch = paddle.zeros([total_len], dtype="int32") # Call the custom operator - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch - ) + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) - expected_position_ids = np.array([0, 1, 2, 1, 0, 1, 2, 3], dtype=np.int32) - - expected_mask = np.array([1, 1, 1, 0, 1, 1, 0, 0], dtype=np.int32) + expected_position_ids = np.array([3, 2, 3], dtype=np.int32) # Convert to numpy for comparison position_ids_np = position_ids.numpy() - mask_encoder_batch_np = mask_encoder_batch.numpy() # Assert equality np.testing.assert_array_equal(position_ids_np, expected_position_ids) - np.testing.assert_array_equal(mask_encoder_batch_np, expected_mask) def test_empty_decoder(self): # Test case where decoder length is 0 seq_lens_encoder = paddle.to_tensor([2], dtype="int32") seq_lens_decoder = paddle.to_tensor([0], dtype="int32") - seq_lens_this_time = paddle.to_tensor([0], dtype="int32") + seq_lens_this_time = paddle.to_tensor([2], dtype="int32") position_ids = paddle.zeros([2], dtype="int32") - mask_encoder_batch = paddle.zeros([2], dtype="int32") - get_position_ids_and_mask_encoder_batch( - seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids, mask_encoder_batch - ) + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids) expected_position_ids = np.array([0, 1], dtype=np.int32) - expected_mask = np.array([1, 1], dtype=np.int32) np.testing.assert_array_equal(position_ids.numpy(), expected_position_ids) - np.testing.assert_array_equal(mask_encoder_batch.numpy(), expected_mask) if __name__ == "__main__": diff --git a/tests/operators/test_get_position_ids_and_slot_mapping.py b/tests/operators/test_get_position_ids_and_slot_mapping.py new file mode 100644 index 00000000000..22bf32a3323 --- /dev/null +++ b/tests/operators/test_get_position_ids_and_slot_mapping.py @@ -0,0 +1,345 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import ( + get_position_ids, + get_position_ids_and_slot_mapping, +) + + +class TestGetPositionIdsAndSlotMapping(unittest.TestCase): + """Test the fused get_position_ids_and_slot_mapping kernel. + + Variable meanings: + - seq_lens_encoder: 0 if decode stage, else prefill length in current step + - seq_lens_decoder: total context length (processed history, prefill + decode) + - seq_lens_this_time: tokens to process in current step + """ + + def setUp(self): + np.random.seed(42) + paddle.set_device("gpu") + + def _compute_position_ids_and_slot_mapping_old( + self, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ): + """Old implementation for comparison.""" + sum_token_num = int(seq_lens_this_time.numpy().sum()) + + # get_position_ids expects int32, so use int32 and then cast to int64 + position_ids_int32 = paddle.zeros([sum_token_num], dtype="int32") + get_position_ids(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, position_ids_int32) + + block_idx = position_ids_int32 // block_size + block_ids = block_tables[batch_id_per_token[:sum_token_num], block_idx] + block_offset = position_ids_int32 % block_size + slot_mapping = (block_ids * block_size + block_offset).cast(paddle.int64) + + # Cast position_ids to int64 for comparison with new kernel + position_ids = position_ids_int32.cast(paddle.int64) + + return position_ids.numpy(), slot_mapping.numpy() + + def _compute_position_ids_and_slot_mapping_new( + self, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ): + """New fused kernel implementation.""" + sum_token_num = int(seq_lens_this_time.numpy().sum()) + # Create output buffers (int64 for kernel compatibility) + position_ids = paddle.zeros([sum_token_num], dtype="int64") + slot_mapping = paddle.zeros([sum_token_num], dtype="int64") + + # Kernel writes directly to buffers + get_position_ids_and_slot_mapping( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + position_ids, + slot_mapping, + block_size, + ) + + return position_ids.numpy(), slot_mapping.numpy() + + def _generate_batch_id_per_token(self, seq_lens_this_time, bsz): + """Generate batch_id_per_token based on seq_lens_this_time.""" + total_tokens = int(seq_lens_this_time.numpy().sum()) + batch_id_per_token = np.zeros([total_tokens], dtype=np.int32) + offset = 0 + for bid in range(bsz): + seq_len = int(seq_lens_this_time[bid].numpy()) + batch_id_per_token[offset : offset + seq_len] = bid + offset += seq_len + return paddle.to_tensor(batch_id_per_token, dtype="int32", place=paddle.CUDAPlace(0)) + + def _generate_block_tables(self, bsz, max_num_blocks): + """Generate block_tables with sequential block ids for reproducibility.""" + block_tables = np.arange(bsz * max_num_blocks, dtype=np.int32).reshape(bsz, max_num_blocks) + return paddle.to_tensor(block_tables, dtype="int32", place=paddle.CUDAPlace(0)) + + def test_single_batch_decode(self): + """Test single batch in decode stage.""" + # Decode stage: already processed 10 tokens, now decode 1 more + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode stage + seq_lens_decoder = paddle.to_tensor([10], dtype="int32") # history length + seq_lens_this_time = paddle.to_tensor([1], dtype="int32") # current step + + batch_id_per_token = paddle.to_tensor([0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10], slot_old=[10] (block_id=0, block_offset=10, slot=0*64+10=10) + # logger.info(f"test_single_batch_decode: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Verify position_id starts from seq_lens_decoder + self.assertEqual(pos_new[0], 10) + + def test_single_batch_prefill(self): + """Test single batch in prefill stage.""" + # Prefill stage: no history, processing 5 tokens + seq_lens_encoder = paddle.to_tensor([5], dtype="int32") # prefill length + seq_lens_decoder = paddle.to_tensor([0], dtype="int32") # no history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current step + + batch_id_per_token = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[0,1,2,3,4], slot_old=[0,1,2,3,4] (all in block 0) + # logger.info(f"test_single_batch_prefill: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Verify position_ids start from 0 + np.testing.assert_array_equal(pos_new, np.array([0, 1, 2, 3, 4])) + + def test_multiple_batches_decode(self): + """Test multiple batches all in decode stage.""" + # Batch 0: history 10, now 1 + # Batch 1: history 20, now 2 + seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32") # both decode + seq_lens_decoder = paddle.to_tensor([10, 20], dtype="int32") # history lengths + seq_lens_this_time = paddle.to_tensor([1, 2], dtype="int32") # current step + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 2) + block_tables = self._generate_block_tables(2, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,20,21] + # Batch 0: position_id=10, block_id=0, block_offset=10, slot=10 + # Batch 1: position_ids=[20,21], batch_id=1, block_tables[1][0]=100 + # slot[1]=100*64+20=6420, slot[2]=100*64+21=6421 + # logger.info(f"test_multiple_batches_decode: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new, "position_ids mismatch") + np.testing.assert_array_equal(slot_old, slot_new, "slot_mapping mismatch") + + # Batch 0: position_id = 10 + # Batch 1: position_ids = [20, 21] + np.testing.assert_array_equal(pos_new, np.array([10, 20, 21])) + + def test_different_block_sizes(self): + """Test with different block sizes.""" + for block_size in [1, 8, 16, 32, 64]: + with self.subTest(block_size=block_size): + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([10], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current + batch_id_per_token = paddle.to_tensor([0] * 5, dtype="int32") + block_tables = self._generate_block_tables(1, 100) + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ) + # Expected: pos_old=[10,11,12,13,14] + # For block_size=64: block_id=0, slot=[10,11,12,13,14] + # For block_size=16: block_id=0 for all (10-14<16), slot=[10,11,12,13,14] + # logger.info(f"test_different_block_sizes[{block_size}]: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + block_tables, + block_size, + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + def test_block_boundary_crossing(self): + """Test tokens crossing block boundaries.""" + # block_size=64, history=60, so position_ids will be [60, 61, 62, 63, 64] + # This crosses the block boundary (60-63 in block 0, 64 in block 1) + seq_lens_encoder = paddle.to_tensor([0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([60], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([5], dtype="int32") # current + batch_id_per_token = paddle.to_tensor([0, 0, 0, 0, 0], dtype="int32") + block_tables = self._generate_block_tables(1, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[60,61,62,63,64] + # position 60-63: block_id=0, block_offset=60-63, slot=60-63 + # position 64: block_id=1, block_offset=0, slot=64 + # logger.info(f"test_block_boundary_crossing: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Verify position_ids + np.testing.assert_array_equal(pos_new, np.array([60, 61, 62, 63, 64])) + + def test_large_batch(self): + """Test with larger batch size.""" + bsz = 16 + # All in decode stage + seq_lens_encoder = paddle.to_tensor([0] * bsz, dtype="int32") + seq_lens_decoder = paddle.to_tensor(np.random.randint(0, 100, size=bsz), dtype="int32") + seq_lens_this_time = paddle.to_tensor(np.random.randint(1, 5, size=bsz), dtype="int32") + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, bsz) + block_tables = self._generate_block_tables(bsz, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Too many tokens to list expected values + # logger.info(f"test_large_batch: shape pos_old={pos_old.shape}, slot_old={slot_old.shape}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + def test_empty_batch(self): + """Test with some batches having zero tokens this step.""" + # Batch 0: decode (1 token) + # Batch 1: skip (0 tokens) + # Batch 2: decode (2 tokens) + seq_lens_encoder = paddle.to_tensor([0, 0, 0], dtype="int32") # all decode + seq_lens_decoder = paddle.to_tensor([10, 20, 5], dtype="int32") # history + seq_lens_this_time = paddle.to_tensor([1, 0, 2], dtype="int32") # current + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 3) + block_tables = self._generate_block_tables(3, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,5,6] + # Batch 0: position_id=10, batch_id=0, block_id=0, slot=10 + # Batch 2: position_ids=[5,6], batch_id=2, block_tables[2][0]=200 + # slot[1]=200*64+5=12805, slot[2]=200*64+6=12806 + # logger.info(f"test_empty_batch: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Batch 0: position_id = 10 + # Batch 1: skipped + # Batch 2: position_ids = [5, 6] + np.testing.assert_array_equal(pos_new, np.array([10, 5, 6])) + + def test_mtp_scenario(self): + """Test MTP scenario where seq_lens_this_time varies per batch.""" + # All in decode stage, different accepted tokens per batch + seq_lens_encoder = paddle.to_tensor([0, 0], dtype="int32") # decode + seq_lens_decoder = paddle.to_tensor([10, 20], dtype="int32") # history + # Batch 0: 2 accepted tokens, Batch 1: 1 accepted token + seq_lens_this_time = paddle.to_tensor([2, 1], dtype="int32") + + batch_id_per_token = self._generate_batch_id_per_token(seq_lens_this_time, 2) + block_tables = self._generate_block_tables(2, 100) + block_size = 64 + + pos_old, slot_old = self._compute_position_ids_and_slot_mapping_old( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + # Expected: pos_old=[10,11,20] + # Batch 0: position_ids=[10,11], batch_id=0, block_id=0, slot=[10,11] + # Batch 1: position_ids=[20], batch_id=1, block_tables[1][0]=100 + # slot[2]=100*64+20=6420 + # logger.info(f"test_mtp_scenario: pos_old={pos_old}, slot_old={slot_old}") + pos_new, slot_new = self._compute_position_ids_and_slot_mapping_new( + seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, batch_id_per_token, block_tables, block_size + ) + + np.testing.assert_array_equal(pos_old, pos_new) + np.testing.assert_array_equal(slot_old, slot_new) + + # Batch 0: position_ids = [10, 11] + # Batch 1: position_id = [20] + np.testing.assert_array_equal(pos_new, np.array([10, 11, 20])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_grouped_topk_op.py b/tests/operators/test_grouped_topk_op.py new file mode 100644 index 00000000000..1e76328eb93 --- /dev/null +++ b/tests/operators/test_grouped_topk_op.py @@ -0,0 +1,485 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the `grouped_topk` custom CUDA op (low-level interface). + +grouped_topk fuses sigmoid into the kernel and accepts raw logits directly, +unlike noaux_tc which requires Python-side sigmoid preprocessing. + +Algorithm: + 1. scores = sigmoid(gating_output) [fused inside kernel] + 2. scores_with_bias = scores + e_score_correction_bias + 3. group_scores = sum of top-2 biased expert scores per group + 4. Select top-topk_group groups + 5. Within selected groups select top-topk experts by biased score + 6. Gather unbiased sigmoid scores for selected experts as topk_values + 7. Optionally renormalize and scale by routed_scaling_factor + +Model configs covered: + DeepSeek-V3 / R1 num_experts=256, n_group=8, topk_group=4, topk=8, renorm=True, scale=2.5 + GLM-4.5-Air num_experts=128, n_group=1, topk_group=1, topk=8, renorm=True, scale=1.0 + Qwen3-30B-A3B num_experts=128, n_group=4, topk_group=2, topk=8, renorm=False, scale=1.0 + Kimi-K2 num_experts=384, n_group=8, topk_group=2, topk=8, renorm=False, scale=1.0 +""" + +import unittest + +import numpy as np +import paddle + +try: + from fastdeploy.model_executor.ops.gpu import grouped_topk + + _GROUPED_TOPK_AVAILABLE = True +except Exception: + _GROUPED_TOPK_AVAILABLE = False + + +def native_grouped_topk( + gating_output: paddle.Tensor, + e_score_correction_bias: paddle.Tensor, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): + """Pure-Python reference that mirrors the grouped_topk kernel semantics. + + Args: + gating_output: raw logits, shape [num_tokens, num_experts] + e_score_correction_bias: bias added to sigmoid scores, shape [1, num_experts] or [num_experts] + n_group: number of expert groups + topk_group: number of groups selected per token + topk: number of experts selected per token + renormalize: whether to L1-normalise the selected weights + routed_scaling_factor: multiplicative scale applied after renorm + + Returns: + (scores_out, topk_values, topk_indices) + scores_out – sparse score tensor, shape [num_tokens, num_experts] + topk_values – weights for selected experts, shape [num_tokens, topk] + topk_indices – expert indices, shape [num_tokens, topk] (int64) + """ + num_tokens, num_experts = gating_output.shape + experts_per_group = num_experts // n_group + + scores = paddle.nn.functional.sigmoid(gating_output) + scores_with_bias = scores + e_score_correction_bias + + # Step 1: group scores = sum of top-2 biased scores per group + biased = scores_with_bias.reshape([num_tokens, n_group, experts_per_group]) + group_scores = biased.topk(min(2, experts_per_group), axis=-1)[0].sum(axis=-1) + + # Step 2: select top-topk_group groups + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] + group_mask = paddle.zeros_like(group_scores) + group_mask.put_along_axis_(group_idx, paddle.ones_like(group_idx, dtype=group_mask.dtype), axis=-1) + score_mask = ( + group_mask.unsqueeze(-1).expand([num_tokens, n_group, experts_per_group]).reshape([num_tokens, num_experts]) + ) + + # Step 3: select top-topk experts within selected groups (biased score) + tmp_scores = scores_with_bias.masked_fill(~score_mask.cast(paddle.bool), float("-inf")) + topk_indices = paddle.topk(tmp_scores, topk, axis=-1)[1] + + # Step 4: gather unbiased sigmoid scores + topk_values = paddle.take_along_axis(scores, topk_indices, axis=1) + + # Step 5: renormalize + scale + if renormalize: + topk_values = topk_values / (topk_values.sum(axis=-1, keepdim=True) + 1e-20) + if routed_scaling_factor != 1.0: + topk_values = topk_values * routed_scaling_factor + + scores_out = paddle.zeros_like(scores) + scores_out.put_along_axis_(topk_indices, topk_values, axis=1) + + return scores_out, topk_values, topk_indices.cast(paddle.int64) + + +@unittest.skipUnless(_GROUPED_TOPK_AVAILABLE, "grouped_topk custom op not available (not compiled)") +class TestGroupedTopkOp(unittest.TestCase): + """Tests for the grouped_topk custom CUDA op.""" + + ATOL = 1e-3 + RTOL = 1e-3 + + def setUp(self): + paddle.seed(42) + + # ------------------------------------------------------------------ + # Parametrised helper + # ------------------------------------------------------------------ + def _run_case( + self, + num_tokens: int, + num_experts: int, + n_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, + input_dtype=paddle.float32, + bias_scale: float = 0.1, + seed: int = 42, + ): + paddle.seed(seed) + gating = paddle.randn([num_tokens, num_experts], dtype=input_dtype) + bias = (paddle.rand([1, num_experts], dtype=paddle.float32) - 0.5) * bias_scale + + # Reference always runs in fp32 + gating_fp32 = gating.cast(paddle.float32) if input_dtype != paddle.float32 else gating + ref_scores, ref_tv, ref_ti = native_grouped_topk( + gating_fp32.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + op_scores, op_tv, op_ti = grouped_topk( + gating.clone(), + bias.clone(), + n_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + label = ( + f"T={num_tokens}, E={num_experts}, n_group={n_group}, " + f"topk_group={topk_group}, topk={topk}, " + f"renorm={renormalize}, scale={routed_scaling_factor}, dtype={input_dtype}" + ) + + self.assertEqual(op_tv.shape, [num_tokens, topk], f"[{label}] topk_values shape") + self.assertEqual(op_ti.shape, [num_tokens, topk], f"[{label}] topk_indices shape") + self.assertEqual(op_ti.dtype, paddle.int64, f"[{label}] topk_indices dtype") + self.assertEqual(op_tv.dtype, paddle.float32, f"[{label}] topk_values dtype") + + # Compare set-level index match (position order not guaranteed) + ref_sorted = paddle.sort(ref_ti, axis=-1) + op_sorted = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_sorted, op_sorted).item(): + n_diff = (ref_sorted != op_sorted).sum().item() + self.fail(f"[{label}] topk_indices set mismatch: {n_diff} positions differ") + + # Align values by expert index before comparing + ref_ord = paddle.argsort(ref_ti, axis=-1) + op_ord = paddle.argsort(op_ti, axis=-1) + ref_tv_s = paddle.take_along_axis(ref_tv, ref_ord, axis=-1) + op_tv_s = paddle.take_along_axis(op_tv, op_ord, axis=-1) + if not paddle.allclose(op_tv_s, ref_tv_s, atol=self.ATOL, rtol=self.RTOL).item(): + max_diff = (op_tv_s - ref_tv_s).abs().max().item() + self.fail(f"[{label}] topk_values max_diff={max_diff:.2e}") + + # ------------------------------------------------------------------ + # GLM-4.5-Air: n_experts=128, n_group=1, topk_group=1, topk=8, renorm=True + # ------------------------------------------------------------------ + def test_glm45air_T1(self): + self._run_case(1, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T32(self): + self._run_case(32, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T128(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T512(self): + self._run_case(512, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T1024(self): + self._run_case(1024, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T4096(self): + self._run_case(4096, 128, 1, 1, 8, True, 1.0) + + def test_glm45air_T8192(self): + self._run_case(8192, 128, 1, 1, 8, True, 1.0) + + # ------------------------------------------------------------------ + # DeepSeek-V3 / R1: n_experts=256, n_group=8, topk_group=4, topk=8, + # renorm=True, scale=2.5 + # ------------------------------------------------------------------ + def test_deepseek_v3_T1(self): + self._run_case(1, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T32(self): + self._run_case(32, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T128(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T512(self): + self._run_case(512, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T4096(self): + self._run_case(4096, 256, 8, 4, 8, True, 2.5) + + def test_deepseek_v3_T8192(self): + self._run_case(8192, 256, 8, 4, 8, True, 2.5) + + # ------------------------------------------------------------------ + # Qwen3-30B-A3B: n_experts=128, n_group=4, topk_group=2, topk=8, + # renorm=False + # ------------------------------------------------------------------ + def test_qwen3_30b_T1(self): + self._run_case(1, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T128(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T512(self): + self._run_case(512, 128, 4, 2, 8, False, 1.0) + + def test_qwen3_30b_T4096(self): + self._run_case(4096, 128, 4, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # Kimi-K2: n_experts=384, n_group=8, topk_group=2, topk=8, renorm=False + # ------------------------------------------------------------------ + def test_kimi_k2_T1(self): + self._run_case(1, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T128(self): + self._run_case(128, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T512(self): + self._run_case(512, 384, 8, 2, 8, False, 1.0) + + def test_kimi_k2_T4096(self): + self._run_case(4096, 384, 8, 2, 8, False, 1.0) + + # ------------------------------------------------------------------ + # bfloat16 input path: kernel should cast internally to fp32 + # ------------------------------------------------------------------ + def test_bf16_input_glm45air(self): + self._run_case(128, 128, 1, 1, 8, True, 1.0, input_dtype=paddle.bfloat16) + + def test_bf16_input_deepseek_v3(self): + self._run_case(128, 256, 8, 4, 8, True, 2.5, input_dtype=paddle.bfloat16) + + def test_bf16_input_qwen3_30b(self): + self._run_case(128, 128, 4, 2, 8, False, 1.0, input_dtype=paddle.bfloat16) + + # ------------------------------------------------------------------ + # Output shape and dtype sanity + # ------------------------------------------------------------------ + def test_output_shapes(self): + """Verify output shapes for various (T, E, topk) combinations.""" + cases = [ + (1, 128, 1, 1, 8), + (32, 256, 8, 4, 8), + (64, 384, 8, 2, 8), + ] + for T, E, ng, tkg, topk in cases: + gating = paddle.randn([T, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertEqual(tv.shape, [T, topk], f"T={T},E={E}: topk_values shape") + self.assertEqual(ti.shape, [T, topk], f"T={T},E={E}: topk_indices shape") + + def test_output_dtype_is_float32(self): + """topk_values must always be float32 regardless of input dtype.""" + for dtype in [paddle.float32, paddle.bfloat16]: + gating = paddle.randn([16, 128], dtype=dtype) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertEqual(tv.dtype, paddle.float32, f"dtype={dtype}: topk_values not float32") + self.assertEqual(ti.dtype, paddle.int64, f"dtype={dtype}: topk_indices not int64") + + # ------------------------------------------------------------------ + # Correctness invariants + # ------------------------------------------------------------------ + def test_topk_indices_in_valid_range(self): + """All selected expert indices must lie in [0, num_experts).""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8), (384, 8, 2, 8)]: + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + self.assertTrue((ti >= 0).all().item(), f"E={E}: negative index found") + self.assertTrue((ti < E).all().item(), f"E={E}: index >= num_experts") + + def test_no_duplicate_experts_per_token(self): + """Each token must select exactly topk distinct experts.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, _, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + for row in ti.numpy(): + self.assertEqual(len(set(row.tolist())), topk, f"E={E}: duplicate expert indices in row {row}") + + def test_topk_values_non_negative(self): + """Sigmoid output is in (0,1); routing weights must be >= 0.""" + gating = paddle.randn([64, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + self.assertTrue((tv >= 0).all().item(), "topk_values contains negative weights") + + def test_renormalized_weights_sum_to_one(self): + """With renormalize=True and scale=1.0, per-token weights sum ≈ 1.""" + num_tokens = 64 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 1, 1, 8, True, 1.0) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.ones(num_tokens, dtype=np.float32), + atol=1e-3, + err_msg="Renormalized weights do not sum to 1 per token", + ) + + def test_scaled_weights_sum(self): + """With renormalize=True and scale=2.5, per-token weights sum ≈ 2.5.""" + num_tokens, scale = 64, 2.5 + gating = paddle.randn([num_tokens, 256], dtype=paddle.float32) + bias = paddle.zeros([1, 256], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, 8, 4, 8, True, scale) + row_sums = tv.sum(axis=-1).numpy() + np.testing.assert_allclose( + row_sums, + np.full(num_tokens, scale, dtype=np.float32), + atol=1e-2, + err_msg=f"Scaled weights do not sum to {scale} per token", + ) + + def test_no_renorm_weights_are_raw_sigmoid(self): + """With renormalize=False, topk_values must equal sigmoid(logits) at selected positions.""" + num_tokens, E = 32, 128 + gating = paddle.randn([num_tokens, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 8, False, 1.0) + expected = paddle.take_along_axis(paddle.nn.functional.sigmoid(gating), ti, axis=1) + np.testing.assert_allclose( + tv.numpy(), + expected.numpy(), + atol=1e-4, + err_msg="Without renorm, topk_values should equal sigmoid(gating) at selected positions", + ) + + def test_deterministic(self): + """Two identical calls must produce bit-for-bit identical outputs.""" + gating = paddle.randn([32, 256], dtype=paddle.float32) + bias = (paddle.rand([1, 256], dtype=paddle.float32) - 0.5) * 0.1 + _, tv1, ti1 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + _, tv2, ti2 = grouped_topk(gating.clone(), bias.clone(), 8, 4, 8, True, 2.5) + self.assertTrue( + paddle.allclose(tv1, tv2, atol=0.0, rtol=0.0).item(), + "topk_values not deterministic across two identical calls", + ) + self.assertTrue( + paddle.equal_all(ti1, ti2).item(), + "topk_indices not deterministic across two identical calls", + ) + + def test_zero_bias(self): + """All-zero bias: biased == unbiased; reference and op must agree.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(16) + gating = paddle.randn([32, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias, ng, tkg, topk, True, 1.0) + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + self.assertTrue( + paddle.equal_all(ref_s, op_s).item(), + f"E={E}/zero_bias: topk_indices set mismatch", + ) + + def test_large_bias_steers_routing(self): + """Large positive bias on first half of experts must dominate selection.""" + E, topk = 128, 8 + paddle.seed(17) + gating = paddle.randn([64, E], dtype=paddle.float32) + bias = paddle.concat( + [ + paddle.full([1, E // 2], 2.0, dtype=paddle.float32), + paddle.full([1, E // 2], -2.0, dtype=paddle.float32), + ], + axis=1, + ) + _, _, ti = grouped_topk(gating, bias, 1, 1, topk, True, 1.0) + self.assertTrue( + (ti < E // 2).all().item(), + "Large positive bias on experts [0, E/2) did not steer all selections there", + ) + + def test_extreme_logits_no_nan_inf(self): + """Very large logits must not produce NaN or Inf in outputs.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + paddle.seed(18) + gating = paddle.randn([8, E], dtype=paddle.float32) * 50.0 + bias = paddle.zeros([1, E], dtype=paddle.float32) + _, tv, _ = grouped_topk(gating, bias, ng, tkg, topk, False, 1.0) + self.assertFalse(paddle.isnan(tv).any().item(), f"E={E}: NaN in topk_values") + self.assertFalse(paddle.isinf(tv).any().item(), f"E={E}: Inf in topk_values") + + def test_single_expert_selected(self): + """topk=1: each token selects exactly one expert; weight == 1.0 with renorm.""" + num_tokens = 16 + gating = paddle.randn([num_tokens, 128], dtype=paddle.float32) + bias = paddle.zeros([1, 128], dtype=paddle.float32) + _, tv, ti = grouped_topk(gating, bias, 1, 1, 1, True, 1.0) + self.assertEqual(tv.shape, [num_tokens, 1]) + self.assertEqual(ti.shape, [num_tokens, 1]) + np.testing.assert_allclose( + tv.numpy(), + np.ones((num_tokens, 1), dtype=np.float32), + atol=1e-5, + err_msg="With topk=1 and renorm=True, each weight should be 1.0", + ) + + def test_sparse_scores_consistency(self): + """Sparse scores tensor: non-zero at selected positions must equal topk_values; zero elsewhere.""" + for E, ng, tkg, topk in [(128, 1, 1, 8), (256, 8, 4, 8)]: + gating = paddle.randn([16, E], dtype=paddle.float32) + bias = paddle.zeros([1, E], dtype=paddle.float32) + s, tv, ti = grouped_topk(gating, bias, ng, tkg, topk, True, 1.0) + gathered = paddle.take_along_axis(s, ti, axis=1) + np.testing.assert_allclose( + gathered.numpy(), + tv.numpy(), + atol=1e-6, + err_msg=f"E={E}: sparse scores at topk positions != topk_values", + ) + nonzero_count = (s != 0).sum(axis=-1) + self.assertTrue( + (nonzero_count == topk).all().item(), + f"E={E}: non-zero count per token != topk", + ) + + def test_irregular_token_counts(self): + """Non-power-of-2 token counts must produce correct shapes and values.""" + irregular_T = [3, 7, 15, 33, 65, 127, 129, 257, 511, 513, 900] + for T in irregular_T: + gating = paddle.randn([T, 128], dtype=paddle.float32) + bias = (paddle.rand([1, 128], dtype=paddle.float32) - 0.5) * 0.1 + _, ref_tv, ref_ti = native_grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + _, op_tv, op_ti = grouped_topk(gating.clone(), bias.clone(), 1, 1, 8, True, 1.0) + self.assertEqual(op_tv.shape, [T, 8], f"T={T}: topk_values shape mismatch") + self.assertEqual(op_ti.shape, [T, 8], f"T={T}: topk_indices shape mismatch") + ref_s = paddle.sort(ref_ti, axis=-1) + op_s = paddle.sort(op_ti, axis=-1) + if not paddle.equal_all(ref_s, op_s).item(): + n_diff = (ref_s != op_s).sum().item() + self.fail(f"T={T}: topk_indices mismatch, {n_diff} positions differ") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/test_noaux_tc_redundant.py b/tests/operators/test_noaux_tc_redundant.py index 60d1aad2a22..f5289e0ab3c 100644 --- a/tests/operators/test_noaux_tc_redundant.py +++ b/tests/operators/test_noaux_tc_redundant.py @@ -1,10 +1,22 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest +from unittest import mock import paddle -from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import ( - moe_topk_select, -) from fastdeploy.model_executor.layers.moe.moe import get_moe_scores @@ -135,15 +147,17 @@ def test_group_topk_using_phi_topk(self): e_score_correction_bias=e_score_correction_bias, ) - topk_values, topk_idx = moe_topk_select( - gating_output=gating_output, - n_group=n_group, - topk_group=topk_group, - top_k=top_k, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - renormalize=renormalize, - ) + with mock.patch.dict("os.environ", {"FD_USE_PHI_MOE_TOPK": "1"}): + new_score, topk_values, topk_idx = get_moe_scores( + gating_output=gating_output, + n_group=n_group, + topk_group=topk_group, + top_k=top_k, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + renormalize=renormalize, + topk_reduce_func=lambda x: x.sum(axis=-1, keepdim=True) + 1e-20, + ) equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item() equal_topk_ids = paddle.allclose( diff --git a/tests/operators/test_speculate_get_target_logits.py b/tests/operators/test_speculate_get_accept_tokens_and_logits.py similarity index 61% rename from tests/operators/test_speculate_get_target_logits.py rename to tests/operators/test_speculate_get_accept_tokens_and_logits.py index 5d930418ae1..70b346f4895 100644 --- a/tests/operators/test_speculate_get_target_logits.py +++ b/tests/operators/test_speculate_get_accept_tokens_and_logits.py @@ -17,7 +17,7 @@ import paddle from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import ( - speculate_get_target_logits, + speculate_get_accept_tokens_and_logits, ) @@ -35,34 +35,38 @@ def test_all_decode(self): seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[2], [2], [2]], dtype="int32") accept_num = paddle.to_tensor([1, 2, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, 21], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 2 - glod_logits[2][:] = 3 - glod_logits[3][:] = 4 + ref_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 2 + ref_logits[2][:] = 3 + ref_logits[3][:] = 4 + ref_token_ids = paddle.to_tensor([10, 20, 21, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) def test_partial_decode(self): token_num = 5 @@ -73,34 +77,38 @@ def test_partial_decode(self): seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32") accept_num = paddle.to_tensor([1, 2, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, 21], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 1 - glod_logits[2][:] = 2 - glod_logits[3][:] = 3 + ref_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 1 + ref_logits[2][:] = 2 + ref_logits[3][:] = 3 + ref_token_ids = paddle.to_tensor([10, 20, 21, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) def test_all_prefill(self): token_num = 3 @@ -111,33 +119,37 @@ def test_all_prefill(self): seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32") seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32") accept_num = paddle.to_tensor([1, 1, 1], dtype="int32") + accept_tokens = paddle.to_tensor([[10, -1], [20, -1], [30, -1]], dtype="int64") batch_token_num = paddle.where( seq_lens_encoder != 0, paddle.ones_like(seq_lens_encoder), seq_lens_this_time, ).squeeze(1) - ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( - "int32" - ) + cu_seqlens_q_output = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32") cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32") + token_ids = paddle.full(shape=[accept_num.sum()], fill_value=0, dtype="int64") target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype) - speculate_get_target_logits( + speculate_get_accept_tokens_and_logits( + token_ids, target_logits, logits, cu_batch_token_offset, - ori_cu_batch_token_offset, + cu_seqlens_q_output, seq_lens_this_time, seq_lens_encoder, accept_num, + accept_tokens, ) - glod_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32") - glod_logits[0][:] = 0 - glod_logits[1][:] = 1 - glod_logits[2][:] = 2 + ref_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32") + ref_logits[0][:] = 0 + ref_logits[1][:] = 1 + ref_logits[2][:] = 2 + ref_token_ids = paddle.to_tensor([10, 20, 30], dtype="int64") - assert paddle.allclose(target_logits, glod_logits) + assert paddle.allclose(target_logits, ref_logits) + assert paddle.equal_all(token_ids, ref_token_ids) if __name__ == "__main__": diff --git a/tests/operators/test_speculate_get_token_penalty_multi_scores.py b/tests/operators/test_speculate_get_token_penalty_multi_scores.py index 845f666ee7d..61efdbf270f 100644 --- a/tests/operators/test_speculate_get_token_penalty_multi_scores.py +++ b/tests/operators/test_speculate_get_token_penalty_multi_scores.py @@ -61,7 +61,7 @@ def update_repeat_times( token_ids_all_now = token_ids_all[bi] repeat_times_now = repeat_times[token_idx] - for i in range(length_id): + for i in range(cur_len[bi]): id = token_ids_all_now[i] if id < 0: break diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 45d8a0ef34f..aa048560c30 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: return paddle_inputs -def run_kernel(paddle_inputs, inputs): +def run_kernel(paddle_inputs): """Call the CUDA kernel.""" speculate_set_stop_value_multi_seqs( paddle_inputs["accept_tokens"], @@ -137,7 +137,18 @@ def gen_inputs( def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]: - """Python reference — must match CUDA kernel logic exactly.""" + """Python reference — must match CUDA kernel logic exactly. + + token_ids_all 布局 (新 step_idx 语义): + pre_ids_now[k] = 第 k 个 output token (k >= 0, 0-indexed) + 最后一个 output token 在 pre_ids_now[step_idx - 1] + step_idx = 历史已生成的 token 数量 + + 核心设计: + 1. accept_idx 从 -1 开始,-1 表示检查 pre_ids 末尾(上一轮延迟的情况) + 2. 主循环检查 accept_idx <= accept_num - 2 + 3. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos + """ accept_tokens = inputs["accept_tokens"].copy() accept_num = inputs["accept_num"].copy() stop_flags = inputs["stop_flags"].copy() @@ -166,27 +177,36 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str step_idx_now = int(step_idx[bid]) min_token_limit = int(min_tokens[bid]) - can_stop = step_idx_now >= min_token_limit + can_stop = step_idx_now + an >= min_token_limit if not can_stop: continue if stop_flags[bid]: continue - accept_idx = 0 + # CUDA kernel: accept_idx 从 -1 开始,检查 pre_ids 末尾 + accept_idx = -1 is_end = False - while accept_idx <= an - 1 and not is_end: + + # loop_end = accept_num > 0 ? accept_num - 2 : -1 + loop_end = an - 2 if an > 0 else -1 + while accept_idx <= loop_end and not is_end: if step_idx_now + accept_idx + 1 < stop_seq_len: accept_idx += 1 continue - # Check one stop_seq match + # 从后向前匹配 stop_seq 的每个 token for i in range(stop_seq_len - 1, -1, -1): + offset = stop_seq_len - 1 - i + accept_tokens_idx = accept_idx - offset cur_token_idx = -1 - if stop_seq_len - 1 - i < accept_idx: - cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1] + + if accept_tokens_idx >= 0: + cur_token_idx = accept_tokens_now[accept_tokens_idx] else: - pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i) - if pre_ids_idx <= 0: + # 新语义: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[0] 是第 1 个 output token + pre_ids_idx = step_idx_now + accept_tokens_idx + if pre_ids_idx < 0: break cur_token_idx = pre_ids_now[pre_ids_idx] @@ -199,9 +219,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx += 1 if is_end: - accept_num[bid] = accept_idx - accept_tokens[bid, accept_idx - 1] = end_ids[0] - # stop_flags[bid] = True # kernel no longer sets stop_flags + # accept_idx 已递增,指向 stop_seq 最后 token 的下一个位置 + # 保留 stop_seq 所有 token,在其后追加 eos + accept_num[bid] = accept_idx + 1 + accept_tokens[bid, accept_idx] = end_ids[0] return { "accept_tokens": accept_tokens, @@ -239,7 +260,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): def _run_and_get(self, inputs): paddle_inputs = to_paddle_inputs(inputs) - run_kernel(paddle_inputs, inputs) + run_kernel(paddle_inputs) return get_outputs(paddle_inputs) def _check_all_outputs(self, inputs, outputs): @@ -264,7 +285,7 @@ def test_configs(self): self._run_full_test(test_cfg) def test_match_in_accept_tokens_only(self): - """Stop seq found entirely within accept_tokens.""" + """Stop seq found entirely within accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10) # Place stop seq [A, B, C] at accept_tokens positions [0,1,2] inputs["accept_num"][:] = 4 @@ -276,9 +297,13 @@ def test_match_in_accept_tokens_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30) + # After loop increment, accept_idx=3, accept_num=4, eos appended at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq def test_match_spanning_pre_ids_and_accept(self): - """Stop seq spans token_ids_all (pre_ids) and accept_tokens.""" + """Stop seq spans token_ids_all (pre_ids) and accept_tokens. Eos appended after stop_seq last token.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -290,12 +315,15 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 6 inputs["accept_num"][:] = 3 - # Kernel matching at accept_idx=2 (3rd token, 0-indexed): - # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1] - # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0] - # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6] - # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]] - inputs["token_ids_all"][0, 6] = 99 + # stop_seq = [99, 11, 22] (len=3) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 6 表示有 6 个历史 output token,在 pre_ids_now[0..5] + # At accept_idx=1 (window ends at accept_tokens[1]=22): + # i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓ + # i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓ + # i=0: offset=2, accept_tokens_idx=-1 -> pre_ids_idx=6+(-1)=5 -> pre_ids[5]=99 vs stop_seq[0]=99 ✓ + inputs["token_ids_all"][0, 5] = 99 # pre_ids_now[5] = 第 6 个 output token (0-indexed) inputs["accept_tokens"][0, :3] = [11, 22, 33] inputs["stop_seqs"][0, 0, :3] = [99, 11, 22] inputs["stop_seqs_len"][0, 0] = 3 @@ -303,12 +331,14 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - # Match at accept_idx=2, loop increments to 3 + # Match at accept_idx=1, loop increments to 2 -> accept_num=3, eos at accept_tokens[2] self.assertEqual(outputs["accept_num"][0], 3) - self.assertEqual(outputs["accept_tokens"][0, 2], -1) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq - def test_match_in_pre_ids_only(self): - """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0.""" + def test_match_in_pre_ids_only_not_detected(self): + """Stop seq ending purely in pre_ids history but NOT at the end position. + The kernel only detects stop_seq at the very end of pre_ids via accept_idx=-1 check. + Stop seq placed earlier in pre_ids should not be detected.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -320,15 +350,13 @@ def test_match_in_pre_ids_only(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70 - # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids - # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check - # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70 - # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60 - # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50 - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 8 表示有 8 个历史 output token,在 pre_ids_now[0..7] + # accept_idx=-1 会检查 pre_ids_now[7] 开始的 stop_seq + # 把 stop_seq 放在 pre_ids_now[2,3,4] - 不会被检测到 + inputs["token_ids_all"][0, 2] = 50 + inputs["token_ids_all"][0, 3] = 60 + inputs["token_ids_all"][0, 4] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 @@ -336,7 +364,8 @@ def test_match_in_pre_ids_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - self.assertEqual(outputs["accept_num"][0], 1) + # No match: stop_seq is in pre_ids but not at the end, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_already_stopped(self): """Kernel skips sequences with stop_flags=True.""" @@ -351,7 +380,7 @@ def test_already_stopped(self): np.testing.assert_array_equal(outputs["accept_num"], inputs["accept_num"]) def test_min_tokens_blocks_stop(self): - """Kernel skips stop check when step_idx < min_tokens.""" + """Kernel skips stop check when step_idx + accept_num < min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -363,20 +392,24 @@ def test_min_tokens_blocks_stop(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Same setup that would match (like test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 + # Place stop_seq in pre_ids at end position (would be detected by accept_idx=-1) + # pre_ids_now[0..7] = 8 个历史 output token + # accept_idx=-1 检查 pre_ids_now[5,6,7] 对应 stop_seq[0,1,2] + inputs["token_ids_all"][0, 5] = 50 + inputs["token_ids_all"][0, 6] = 60 + inputs["token_ids_all"][0, 7] = 70 inputs["accept_tokens"][0, :3] = [1, 2, 3] inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop + inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # min_tokens prevents stop, accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_min_tokens_allows_stop(self): - """Kernel allows stop when step_idx >= min_tokens.""" + """Kernel allows stop when step_idx + accept_num >= min_tokens.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -388,15 +421,17 @@ def test_min_tokens_allows_stop(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 7] = 60 - inputs["token_ids_all"][0, 8] = 70 - inputs["accept_tokens"][0, :3] = [1, 2, 3] - inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] - inputs["stop_seqs_len"][0, 0] = 3 + # stop_seq [X, 50] spans pre_ids and accept_tokens[0]. + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # At accept_idx=0 (window ends at accept_tokens[0]=50): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=8+(-1)=7 -> pre_ids[7] + pre_val = int(inputs["token_ids_all"][0, 7]) # pre_ids_now[7] + inputs["accept_tokens"][0, :3] = [50, 60, 70] + inputs["stop_seqs"][0, 0, :2] = [pre_val, 50] + inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop + inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) @@ -413,20 +448,24 @@ def test_multiple_stop_seqs_second_matches(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # accept_tokens: stop_seq[20,30] matches at accept_idx=2: - # i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK - # i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK + # accept_tokens: [20, 30, 40] + # Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30): + # i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓ + # i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓ inputs["accept_tokens"][0, :3] = [20, 30, 40] # First stop seq doesn't match inputs["stop_seqs"][0, 0, :3] = [99, 98, 97] inputs["stop_seqs_len"][0, 0] = 3 - # Second stop seq matches + # Second stop seq [20, 30] matches inputs["stop_seqs"][0, 1, :2] = [20, 30] inputs["stop_seqs_len"][0, 1] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=1 -> accept_num=3, eos at accept_tokens[2] + self.assertEqual(outputs["accept_num"][0], 3) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # eos appended after stop_seq def test_nonzero_prompt_lens(self): """Verify prompt_lens offset is applied correctly.""" @@ -444,19 +483,104 @@ def test_nonzero_prompt_lens(self): inputs["accept_num"][:] = 2 inputs["accept_tokens"][0, :2] = [55, 66] # pre_ids_now starts at token_ids_all[0, prompt_len:] - # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx] - # For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4 - # -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4] - # For accept_idx=1 (second token is accept_tokens[0,0]=55): - # i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55 - # i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5] - target_val = int(inputs["token_ids_all"][0, prompt_len + 5]) + # pre_ids_now[k] = 第 k 个 output token (k >= 0) + # 新索引公式: pre_ids_idx = step_idx_now + accept_tokens_idx + # stop_seq = [X, 55] where X = pre_ids_now[5 + (-1)] = pre_ids_now[4] + # At accept_idx=0 (window ends at accept_tokens[0]=55): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids_idx=5+(-1)=4 -> pre_ids[4]=token_ids_all[0, prompt_len+4] + target_val = int(inputs["token_ids_all"][0, prompt_len + 4]) inputs["stop_seqs"][0, 0, :2] = [target_val, 55] inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=0 -> accept_num=2, eos at accept_tokens[1] + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], -1) # eos appended after stop_seq + + def test_single_token_stop_seq_preserved(self): + """Single token stop_seq (like <|im_end|>) with eos appended after it.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=90, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999 + inputs["accept_tokens"][0, :4] = [100, 200, 999, 300] + # stop_seq = [<|im_end|>] (single token) + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # Match at accept_idx=2 (window ends at accept_tokens[2]=999) + # After loop increment, accept_idx=3, accept_num=4, eos at accept_tokens[3] + self.assertEqual(outputs["accept_num"][0], 4) + self.assertEqual(outputs["accept_tokens"][0, 3], -1) # eos appended after stop_seq + + def test_stop_seq_at_last_position_not_detected(self): + """Stop seq at the last position of accept_tokens is NOT detected (deferred to next round).""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=100, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # stop_seq [999] is at accept_tokens[3] (last valid position) + # Since we only check up to accept_num - 2 = 2, this won't be detected + inputs["accept_tokens"][0, :4] = [100, 200, 300, 999] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # No match because accept_idx only goes up to 2, and 999 is at position 3 + # accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 4) + + def test_stop_seq_detected_from_previous_round(self): + """Stop seq at the end of pre_ids (from previous round) is detected via accept_idx=-1.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=110, + ) + inputs["prompt_lens"][:] = 0 + # 新语义: pre_ids_now[k] = 第 k 个 output token (k >= 0) + # step_idx = 10 表示有 10 个历史 output token,在 pre_ids_now[0..9] + # accept_idx=-1 检查 pre_ids_now[9] (最后一个历史 token) + inputs["step_idx"][:] = 10 + inputs["token_ids_all"][0, 9] = 999 # pre_ids_now[9] = 第 10 个 output token (0-indexed) + inputs["accept_num"][:] = 3 + inputs["accept_tokens"][0, :3] = [100, 200, 300] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # stop_seq [999] was in pre_ids at end, accept_idx=-1 matches + # After loop increment, accept_idx=0, accept_num=1, eos at accept_tokens[0] + self.assertEqual(outputs["accept_num"][0], 1) + self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos if __name__ == "__main__": diff --git a/tests/operators/test_tritonmoe_preprocess.py b/tests/operators/test_tritonmoe_preprocess.py index 94d85c956e1..7071e275225 100644 --- a/tests/operators/test_tritonmoe_preprocess.py +++ b/tests/operators/test_tritonmoe_preprocess.py @@ -12,12 +12,159 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Correctness tests for tritonmoe_preprocess +========================================== + +Tests the fastdeploy wrapper: + tritonmoe_preprocess(topk_ids, num_experts, block_size) + -> (sorted_token_ids, expert_ids, num_tokens_post_padded) + +The verification approach mirrors FlagTree/python/tutorials/tle/02-moe_align_block_size.py: + - Use paddle.bincount as an independent reference (no second kernel to cross-compare). + - Validate three dimensions: + 1. num_tokens_post_padded – total token count after per-expert block alignment + 2. expert_ids – each block is mapped to the correct expert + 3. sorted_token_ids – every token is routed to the right expert's slot, + and padding slots carry sentinel values >= num_tokens +""" + import unittest import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess +# --------------------------------------------------------------------------- +# Import guard – skip entire module when CUDA is unavailable or +# fastdeploy is not installed (e.g. CPU-only CI environments). +# --------------------------------------------------------------------------- +try: + from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess + + _AVAILABLE = paddle.device.is_compiled_with_cuda() +except Exception: + _AVAILABLE = False + +DEVICE = "gpu" + +# 仅对小规模 case 打印详细 tensor,超过此阈值只打印统计摘要 +_PRINT_TENSOR_NUMEL_LIMIT = 64 + + +def _fmt_tensor(t: paddle.Tensor, name: str) -> str: + t_cpu = t.cpu() + if t_cpu.numel() <= _PRINT_TENSOR_NUMEL_LIMIT: + return f"{name}{list(t_cpu.shape)} = {t_cpu.tolist()}" + return ( + f"{name}{list(t_cpu.shape)} | " + f"min={int(t_cpu.min())} max={int(t_cpu.max())} " + f"mean={float(t_cpu.cast('float32').mean()):.2f} numel={t_cpu.numel()}" + ) + + +# --------------------------------------------------------------------------- +# Reference helpers (CPU, independent of the kernel under test) +# --------------------------------------------------------------------------- + + +def _ref_counts_and_cumsum(topk_ids_flat: paddle.Tensor, num_experts: int, block_size: int): + """ + Compute per-expert token counts and the cumulative sum of block-aligned counts. + + Returns: + counts : int32 tensor of shape (num_experts,) + cumsum : int32 tensor of shape (num_experts,) – cumulative aligned counts + """ + # Only consider valid expert ids [0, num_experts); ignore -1 (EP filtered) + valid_mask = (topk_ids_flat >= 0) & (topk_ids_flat < num_experts) + valid_ids = topk_ids_flat[valid_mask] + counts = paddle.bincount(valid_ids.cast("int64"), minlength=num_experts).cast("int32") + aligned = ((counts + block_size - 1) // block_size) * block_size + cumsum = paddle.cumsum(aligned, axis=0).cast("int32") + return counts, cumsum + + +# --------------------------------------------------------------------------- +# Core verification logic (shared across all test cases) +# --------------------------------------------------------------------------- + + +def _verify(topk_ids: paddle.Tensor, block_size: int, num_experts: int, label: str = ""): + """ + Run tritonmoe_preprocess and verify all three output tensors. + topk_ids may be 1-D or 2-D; dtype int32 or int64. + Prints inputs, golden references, kernel outputs, and per-check comparison. + """ + tag = f"[{label}] " if label else "" + sep = "=" * 70 + + sorted_token_ids, expert_ids, num_tokens_post_pad = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + topk_ids_flat = topk_ids.flatten().cast("int64").cpu() + num_tokens = topk_ids_flat.numel() + + counts, cumsum = _ref_counts_and_cumsum(topk_ids_flat, num_experts, block_size) + aligned = ((counts + block_size - 1) // block_size) * block_size + valid_length = int(cumsum[-1].item()) + num_blocks = valid_length // block_size + + expected_expert_ids = paddle.repeat_interleave( + paddle.arange(num_experts, dtype="int32"), # CPU + (aligned // block_size).cast("int32"), + ) + + np.testing.assert_array_equal( + num_tokens_post_pad.cpu().numpy(), + cumsum[-1:].cpu().numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 2: expert_ids – each block maps to the expected expert # + # ------------------------------------------------------------------ # + got_eids = expert_ids[:num_blocks].cpu() + want_eids = expected_expert_ids.cpu() + np.testing.assert_array_equal( + got_eids.numpy(), + want_eids.numpy(), + ) + + # ------------------------------------------------------------------ # + # Check 3: sorted_token_ids – routing correctness per expert # + # ------------------------------------------------------------------ # + + start = 0 + for expert_id in range(num_experts): + end = int(cumsum[expert_id].item()) + tokens = sorted_token_ids[start:end].cpu() + valid_tokens = tokens[tokens < num_tokens] + # padding_tokens = tokens[tokens >= num_tokens] + + want_count = int(counts[expert_id].item()) + got_count = valid_tokens.numel() + count_ok = got_count == want_count + + assert count_ok, f"expert {expert_id}: expected {want_count} valid tokens, got {got_count}" + if counts[expert_id] > 0: + np.testing.assert_array_equal( + topk_ids_flat[valid_tokens.cast("int64")].numpy(), + paddle.full_like(valid_tokens, expert_id).numpy(), + ) + start = end + + # padding 区域哨兵检查 + if valid_length < sorted_token_ids.numel(): + padding_region = sorted_token_ids[valid_length:].cpu() + sentinel_ok = paddle.all(padding_region >= num_tokens).item() + + assert sentinel_ok, "padding slots beyond valid_length contain non-sentinel values" + + print(f"\n{tag}ALL CHECKS PASSED") + print(sep) + + +# --------------------------------------------------------------------------- +# Original unittest-based tests (kept for backward compatibility) +# --------------------------------------------------------------------------- class TestTritonMOEPreprocess(unittest.TestCase): @@ -35,10 +182,14 @@ def _check_output_shapes( self, sorted_ids, expert_ids, num_tokens_post_pad, topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M ): """Check output shapes and dtypes""" - expected_max_num_tokens_padded = topk_ids_np.size + num_experts * (GEMM_BLOCK_SIZE_M - 1) + if topk_ids_np.size < num_experts + 1: + expected_max_num_tokens_padded = topk_ids_np.size * GEMM_BLOCK_SIZE_M + else: + expected_max_num_tokens_padded = topk_ids_np.size + (num_experts + 1) * (GEMM_BLOCK_SIZE_M - 1) + self.assertEqual(sorted_ids.shape[0], expected_max_num_tokens_padded) - expected_max_num_m_blocks = expected_max_num_tokens_padded // GEMM_BLOCK_SIZE_M + expected_max_num_m_blocks = (expected_max_num_tokens_padded + GEMM_BLOCK_SIZE_M - 1) // GEMM_BLOCK_SIZE_M self.assertEqual(expert_ids.shape[0], expected_max_num_m_blocks) self.assertEqual(num_tokens_post_pad.shape[0], 1) @@ -104,17 +255,232 @@ def test_basic_case(self): ) self._check_output_values_basic(sorted_ids, expert_ids, num_tokens_post_pad) - def test_unsupported_num_experts(self): - """Test unsupported num_experts raises OSError""" - topk_ids_np = np.array([[0, 1], [1, 0]], dtype=np.int64) - unsupported_experts = [3, 9, 65, 129] - GEMM_BLOCK_SIZE_M = 4 - for num_experts in unsupported_experts: - with self.subTest(num_experts=num_experts): - with self.assertRaises(OSError): - self._run_op(topk_ids_np, num_experts, GEMM_BLOCK_SIZE_M) +# --------------------------------------------------------------------------- +# Correctness tests (ported from test_moe_align_block_size.py) +# --------------------------------------------------------------------------- + + +class TestTritonMoePreprocessBasic(unittest.TestCase): + """Basic / small cases – easy to reason about manually.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_docstring_example(self): + """Reproduce the example from the function docstring.""" + topk_ids = paddle.to_tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], dtype="int64") + _verify(topk_ids, block_size=4, num_experts=5, label="docstring_example") + + def test_single_token_single_expert(self): + """Minimal input: one token assigned to one expert.""" + topk_ids = paddle.to_tensor([[0]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="single_token_single_expert") + + def test_all_tokens_same_expert(self): + """All tokens go to expert 0 – only one expert's slot is used.""" + topk_ids = paddle.zeros((64, 1), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="all_tokens_same_expert") + + def test_uniform_1d(self): + """1-D topk_ids (top_k=1 squeezed) with uniform distribution.""" + paddle.seed(42) + topk_ids = paddle.randint(0, 8, (128,), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="uniform_1d") + + def test_topk_equals_num_experts(self): + """Every token selects all experts (top_k == num_experts).""" + num_experts = 4 + topk_ids = paddle.arange(num_experts, dtype="int64").unsqueeze(0).expand((8, num_experts)) + _verify(topk_ids, block_size=4, num_experts=num_experts, label="topk_equals_num_experts") + + def test_num_tokens_less_than_num_experts(self): + """Fewer tokens than experts – exercises the small-input branch.""" + topk_ids = paddle.to_tensor([[0], [3]], dtype="int64") + _verify(topk_ids, block_size=16, num_experts=64, label="num_tokens_less_than_num_experts") + + def test_exact_block_boundary(self): + """Token count per expert is exactly block_size – no padding needed.""" + block_size = 16 + num_experts = 4 + topk_ids = paddle.concat([paddle.full((block_size,), e, dtype="int64") for e in range(num_experts)]) + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label="exact_block_boundary") + + def test_block_size_1(self): + """block_size=1 means no padding is ever added.""" + paddle.seed(0) + topk_ids = paddle.randint(0, 16, (64,), dtype="int64") + _verify(topk_ids, block_size=1, num_experts=16, label="block_size_1") + + +class TestTritonMoePreprocessEdgeCases(unittest.TestCase): + """Edge / boundary cases.""" + + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def test_empty_topk_ids(self): + """Zero-token input should not crash; num_tokens_post_pad == 0.""" + topk_ids = paddle.empty((0,), dtype="int64").cuda() + sorted_ids, expert_ids_out, num_post = tritonmoe_preprocess(topk_ids, 8, 16) + got = int(num_post.item()) + + self.assertEqual(got, 0) + + def test_one_expert(self): + """Single expert: all tokens must end up in expert 0's bucket.""" + paddle.seed(1) + topk_ids = paddle.zeros((32,), dtype="int64") + _verify(topk_ids, block_size=8, num_experts=1, label="one_expert") + + def test_large_block_size(self): + """block_size larger than total tokens.""" + topk_ids = paddle.randint(0, 4, (8,), dtype="int64") + _verify(topk_ids, block_size=128, num_experts=4, label="large_block_size") + + def test_int64_dtype(self): + """topk_ids in int64 – the kernel should handle dtype conversion.""" + paddle.seed(7) + topk_ids = paddle.randint(0, 8, (64, 2), dtype="int64") + _verify(topk_ids, block_size=16, num_experts=8, label="int64_dtype") + + +class TestTritonMoePreprocessRealistic(unittest.TestCase): + """Larger, more realistic MoE shapes.""" + def setUp(self): + if not _AVAILABLE: + self.skipTest("CUDA or fastdeploy not available") + + def _run_uniform_distribution(self, num_tokens, num_experts, block_size): + """Uniform random token-to-expert assignment across common MoE shapes.""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"uniform_T{num_tokens}_E{num_experts}_B{block_size}", + ) + + def test_uniform_distribution(self): + """Uniform random token-to-expert assignment across common MoE shapes.""" + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 128, 128), + (16384, 256, 128), + (16384, 512, 256), + (32768, 512, 256), + (32768, 512, 64), + (163840, 1024, 256), + ]: + with self.subTest(num_tokens=num_tokens, num_experts=num_experts, block_size=block_size): + self._run_uniform_distribution(num_tokens, num_experts, block_size) + + def _run_topk_2d(self, num_tokens, top_k, num_experts, block_size): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + paddle.seed(0) + topk_ids = paddle.randint(0, num_experts, (num_tokens, top_k), dtype="int64") + _verify( + topk_ids, + block_size=block_size, + num_experts=num_experts, + label=f"topk2d_T{num_tokens}_K{top_k}_E{num_experts}_B{block_size}", + ) + + def test_topk_2d(self): + """2-D topk_ids as produced by the router (shape [num_tokens, top_k]).""" + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + with self.subTest(num_tokens=num_tokens, top_k=top_k, num_experts=num_experts, block_size=block_size): + self._run_topk_2d(num_tokens, top_k, num_experts, block_size) + + def _run_zipf_distribution(self, alpha): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + num_tokens, num_experts, block_size = 8192, 64, 16 + ranks = paddle.arange(1, num_experts + 1, dtype="float32") + probs = 1.0 / ranks**alpha + probs = probs / probs.sum() + paddle.seed(0) + topk_ids = paddle.multinomial(probs, num_tokens, replacement=True).cast("int64") + _verify(topk_ids, block_size=block_size, num_experts=num_experts, label=f"zipf_alpha{alpha}") + + def test_zipf_distribution(self): + """Skewed (Zipf) token distribution – simulates real MoE load imbalance.""" + for alpha in [0.5, 1.2, 2.0]: + with self.subTest(alpha=alpha): + self._run_zipf_distribution(alpha) + + def test_deterministic_with_fixed_seed(self): + """Same seed must produce the same outputs (kernel is deterministic).""" + num_tokens, num_experts, block_size = 4096, 64, 16 + + paddle.seed(99) + topk_ids = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s1, e1, n1 = tritonmoe_preprocess(topk_ids, num_experts, block_size) + + paddle.seed(99) + topk_ids2 = paddle.randint(0, num_experts, (num_tokens,), dtype="int64").cuda() + s2, e2, n2 = tritonmoe_preprocess(topk_ids2, num_experts, block_size) + + valid = int(n1.item()) + + np.testing.assert_array_equal(n1.numpy(), n2.numpy()) + np.testing.assert_array_equal(e1[: valid // block_size].numpy(), e2[: valid // block_size].numpy()) + np.testing.assert_array_equal(paddle.sort(s1[:valid]).numpy(), paddle.sort(s2[:valid]).numpy()) + + +# --------------------------------------------------------------------------- +# Direct-run entry point (python test_tritonmoe_preprocess.py) +# --------------------------------------------------------------------------- if __name__ == "__main__": - unittest.main() + if not _AVAILABLE: + print("SKIP: CUDA or fastdeploy not available.") + else: + basic = TestTritonMoePreprocessBasic() + basic.test_docstring_example() + basic.test_single_token_single_expert() + basic.test_all_tokens_same_expert() + basic.test_uniform_1d() + basic.test_topk_equals_num_experts() + basic.test_num_tokens_less_than_num_experts() + basic.test_exact_block_boundary() + basic.test_block_size_1() + + edge = TestTritonMoePreprocessEdgeCases() + edge.test_empty_topk_ids() + edge.test_one_expert() + edge.test_large_block_size() + edge.test_int64_dtype() + + real = TestTritonMoePreprocessRealistic() + for num_tokens, num_experts, block_size in [ + (256, 8, 16), + (1024, 16, 16), + (4096, 64, 16), + (8192, 64, 32), + (8192, 128, 64), + (16384, 256, 128), + ]: + real._run_uniform_distribution(num_tokens, num_experts, block_size) + for num_tokens, top_k, num_experts, block_size in [ + (512, 2, 8, 16), + (1024, 4, 16, 16), + (2048, 8, 64, 16), + ]: + real._run_topk_2d(num_tokens, top_k, num_experts, block_size) + for alpha in [0.5, 1.2, 2.0]: + real._run_zipf_distribution(alpha) + real.test_deterministic_with_fixed_seed() + + print("\n*** All direct-run tests passed ***") diff --git a/tests/operators/test_unified_update_model_status.py b/tests/operators/test_unified_update_model_status.py index 56656fdbe75..ed97aa86879 100644 --- a/tests/operators/test_unified_update_model_status.py +++ b/tests/operators/test_unified_update_model_status.py @@ -261,7 +261,9 @@ def reference_impl(inputs: Dict[str, Any]) -> Dict[str, Any]: # Write history to token_ids_all (forward loop, mirrors kernel step 5) if output_len > 0: base_addr = int(prompt_lens[batch_id]) - base = cur_step_idx - output_len + 1 + # 新语义: step_idx 入口 = 历史数量,处理后 cur_step_idx = 历史 + output_len + # 第一个 output token 写入位置 = cur_step_idx - output_len + base = cur_step_idx - output_len for i in range(output_len): write_idx = base_addr + base + i if 0 <= write_idx < max_model_len: diff --git a/tests/output/test_process_batch_draft_tokens.py b/tests/output/test_process_batch_draft_tokens.py index 3686dd1b64b..eef5df62cc9 100644 --- a/tests/output/test_process_batch_draft_tokens.py +++ b/tests/output/test_process_batch_draft_tokens.py @@ -30,6 +30,8 @@ def setUp(self): # 模拟 cfg cfg = MagicMock() cfg.speculative_config = MagicMock() + cfg.parallel_config.local_data_parallel_id = 0 + cfg.parallel_config.engine_worker_queue_port = ["9700"] cfg.speculative_config.method = "mtp" cfg.speculative_config.num_speculative_tokens = 3 cfg.model_config = MagicMock() diff --git a/tests/output/test_process_batch_output.py b/tests/output/test_process_batch_output.py index 46282cd386a..04aa08935af 100644 --- a/tests/output/test_process_batch_output.py +++ b/tests/output/test_process_batch_output.py @@ -21,7 +21,7 @@ import paddle -from fastdeploy.engine.request import RequestMetrics, RequestOutput +from fastdeploy.engine.request import RequestMetrics, RequestOutput, RequestStatus from fastdeploy.output.token_processor import TokenProcessor paddle.set_device("cpu") @@ -65,6 +65,7 @@ class CacheConfig: model_config = ModelConfig() scheduler_config = SchedulerConfig() cache_config = CacheConfig() + routing_replay_config = MagicMock(enable_routing_replay=False) class MockTask: @@ -81,6 +82,7 @@ def __init__(self): self.ic_req_data = {} self.prompt_token_ids_len = 0 self.trace_carrier = {} + self.status = RequestStatus.RUNNING_DECODE now = time.time() self.metrics = RequestMetrics( @@ -166,6 +168,8 @@ def setup_token_processor(self, speculative_decoding=False, use_logprobs=False): processor.total_step_per_request = {} processor.accept_token_num_per_head_per_request = {} processor.accept_token_num_per_head = [0] * MAX_DRAFT_TOKENS + processor.use_sampling_mask = False + processor._benchmark_logger = None # processor._recycle_resources = Mock() @@ -207,8 +211,9 @@ def test_speculative_decoding_use_logprobs(self): # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3, decode = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # meta[1] packs mtype (low 8 bits) and actual_topk (high 16 bits) + actual_topk = K + 1 + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | (actual_topk << 8))) # batch processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # accept_num @@ -240,12 +245,12 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.token_ids) == accept_num[i] assert len(request_output.outputs.top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.top_logprobs[0][0]) == actual_topk + assert len(request_output.outputs.top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.top_logprobs[2]) == accept_num[i] # mtype = 4 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4)) + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(4 | (actual_topk << 8))) processor._process_batch_output() cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens for c in cached_generated_tokens.cache: @@ -254,8 +259,8 @@ def test_speculative_decoding_use_logprobs(self): assert len(request_output.outputs.top_logprobs) == 3 assert len(request_output.outputs.draft_top_logprobs) == 3 # tokens, scores, ranks - assert len(request_output.outputs.draft_top_logprobs[0][0]) == K + 1 - assert len(request_output.outputs.draft_top_logprobs[1][0]) == K + 1 + assert len(request_output.outputs.draft_top_logprobs[0][0]) == actual_topk + assert len(request_output.outputs.draft_top_logprobs[1][0]) == actual_topk assert len(request_output.outputs.draft_top_logprobs[2]) == accept_num[i] def test_process_batch_output_aborted_task_negative_token_speculative_decoding(self): @@ -278,8 +283,8 @@ def test_process_batch_output_aborted_task_negative_token_speculative_decoding(s # Set up output tokens with negative token # stop_flag processor.output_tokens[0, 0].set_tensor(paddle.to_tensor(2)) - # mtype target = 3 - processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3)) + # mtype target = 3, actual_topk packed in high bits + processor.output_tokens[1, 0].set_tensor(paddle.to_tensor(3 | ((K + 1) << 8))) # batch = 2 (so batch_id=0 is < batch_size-1=1) processor.output_tokens[2, 0].set_tensor(paddle.to_tensor(2)) # Set accept_num = PREEMPTED_TOKEN_ID (-9) for first task to trigger abort logic diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py index 07826e6f0eb..8244bb06bbf 100644 --- a/tests/output/test_process_batch_output_use_zmq.py +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -31,6 +31,7 @@ def setUp(self): self.cfg.model_config.enable_logprob = True self.cfg.speculative_config.method = None self.cfg.parallel_config.local_data_parallel_id = 0 + self.cfg.parallel_config.engine_worker_queue_port = ["9700"] self.cached_generated_tokens = MagicMock() self.engine_worker_queue = MagicMock() self.split_connector = MagicMock() diff --git a/tests/output/test_token_processor.py b/tests/output/test_token_processor.py index c0609094a2b..ca63c17c903 100644 --- a/tests/output/test_token_processor.py +++ b/tests/output/test_token_processor.py @@ -25,7 +25,12 @@ import pytest from fastdeploy import envs -from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput +from fastdeploy.engine.request import ( + Request, + RequestMetrics, + RequestOutput, + RequestStatus, +) from fastdeploy.output import token_processor from fastdeploy.output.token_processor import ( MAX_BSZ, @@ -64,6 +69,7 @@ def __init__( ) self.max_num_seqs = max_num_seqs self.splitwise_version = "v1" + self.routing_replay_config = types.SimpleNamespace(enable_routing_replay=False) class _DummyResourceManager: @@ -601,8 +607,8 @@ def test_recycle_resources_prefill_failure_sets_error(): with mock.patch.object(envs, "ENABLE_V1_KVCACHE_SCHEDULER", False): processor._recycle_resources(task_id, 0, task, result, is_prefill=True) - assert result.error_code == 400 - assert "failed" in result.error_message + assert result.error_code == 501 + assert "failed" in result.error_msg assert connector.calls and connector.calls[0][1][0] is result @@ -670,6 +676,7 @@ def test_process_batch_output_consumes_tokens_and_finishes_task(): prompt_token_ids_len=0, num_total_tokens=1, block_tables=[1], + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda key, default=None: getattr(task, key, default) @@ -707,11 +714,13 @@ def test_process_batch_output_logprob_records_topk_and_caching(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 1 + # mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits) + processor.output_tokens[1, 0] = 1 | ((K + 1) << 16) token_block = np.arange(K + 1, dtype=np.int64) + 3 processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(token_block.reshape([-1, 1])) processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32") @@ -740,7 +749,7 @@ def test_process_batch_output_speculative_logprob_handles_draft_batch(): ) processor._batch_result_buffer = [target] processor.cached_generated_tokens = mock.Mock() - processor.output_tokens[1, 0] = 4 + processor.output_tokens[1, 0] = 4 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 1 @@ -783,6 +792,7 @@ def test_process_batch_output_speculative_recovery_stop_finishes(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -833,7 +843,8 @@ def test_process_batch_output_prefill_chunk_and_adapter_skip(): task.get = lambda key, default=None: getattr(task, key, default) rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 1 + # mtext[1] packs bsz (low 16 bits) | actual_topk (high 16 bits) + processor.output_tokens[1, 0] = 1 | ((K + 1) << 16) processor.output_tokens[2 : 2 + K + 1] = paddle.to_tensor(np.ones([K + 1, 1], dtype=np.int64)) processor.output_scores[: K + 1] = paddle.ones([K + 1, 1], dtype="float32") processor.output_ranks[0] = paddle.to_tensor(0, dtype="int64") @@ -910,11 +921,12 @@ def test_process_batch_output_speculative_logprob_targets_topk_scores(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task rm.req_dict[task.request_id] = task - processor.output_tokens[1, 0] = 3 + processor.output_tokens[1, 0] = 3 | ((K + 1) << 8) processor.output_tokens[2, 0] = 1 processor.output_tokens[3, 0] = 2 token_block = np.arange(MAX_DRAFT_TOKENS * (K + 1), dtype=np.int64).reshape([-1, 1]) + 3 @@ -1075,6 +1087,7 @@ def test_process_batch_output_records_second_decode_token(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.metrics.inference_start_time = time.time() @@ -1144,6 +1157,7 @@ def test_process_batch_output_prefill_sets_draft_tokens(): num_total_tokens=1, block_tables=[1], get=lambda key, default=None: None, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None rm.tasks_list[0] = task @@ -1185,6 +1199,7 @@ def test_process_batch_output_logs_recovery_stop_for_non_speculative(): prompt_token_ids_len=0, num_total_tokens=1, block_tables=[1], + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda k, d=None: getattr(task, k, d) @@ -1222,6 +1237,7 @@ def test_process_batch_output_sets_multimodal_token_counts(): num_total_tokens=1, block_tables=[1], multimodal_inputs={"num_input_image_tokens": 4, "num_input_video_tokens": 5}, + status=RequestStatus.RUNNING_DECODE, ) task.trace_carrier = None task.get = lambda key, default=None: getattr(task, key, default) diff --git a/tests/output/test_token_processor_trace_print.py b/tests/output/test_token_processor_trace_print.py index 9ba9b45dfae..d43183705fb 100644 --- a/tests/output/test_token_processor_trace_print.py +++ b/tests/output/test_token_processor_trace_print.py @@ -23,6 +23,9 @@ class TestTokenProcessorMetrics: def setup_method(self): self.mock_cfg = MagicMock() + self.mock_cfg.parallel_config.local_data_parallel_id = 0 + self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] + self.mock_cfg.scheduler_config.splitwise_role = "decode" self.mock_cached_tokens = MagicMock() self.mock_engine_queue = MagicMock() self.mock_split_connector = MagicMock() @@ -74,9 +77,10 @@ def test_record_completion_metrics(self, caplog): with caplog.at_level(logging.INFO): self.processor._record_completion_metrics(self.task, current_time) - assert len(caplog.records) == 2 + assert len(caplog.records) == 3 assert "[request_id=test123]" in caplog.text assert "[event=INFERENCE_END]" in caplog.text + assert "[event=DECODE_INFERENCE_END]" in caplog.text assert "[event=POSTPROCESSING_START]" in caplog.text # Verify metrics are updated diff --git a/tests/quantization/test_modelopt_nvfp4.py b/tests/quantization/test_modelopt_nvfp4.py index 3bf4653c725..27b5ac1309a 100644 --- a/tests/quantization/test_modelopt_nvfp4.py +++ b/tests/quantization/test_modelopt_nvfp4.py @@ -22,6 +22,9 @@ import paddle +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda *args, **kwargs: None + import fastdeploy.model_executor.layers.quantization.nvfp4 as nvfp4_module from fastdeploy.model_executor.layers.linear import QKVParallelLinear from fastdeploy.model_executor.layers.moe import FusedMoE @@ -133,7 +136,7 @@ def test_module_import_with_flashinfer(self): """Test module reloading when flashinfer is available.""" mock_flashinfer = types.ModuleType("flashinfer") with mock.patch.dict(sys.modules, {"flashinfer": mock_flashinfer}): - with mock.patch("paddle.compat.enable_torch_proxy"): + with mock.patch("paddle.enable_compat"): importlib.reload(nvfp4_module) diff --git a/tests/rl/test_dynamic_weight_gdr.py b/tests/rl/test_dynamic_weight_gdr.py new file mode 100644 index 00000000000..f6b1be0baad --- /dev/null +++ b/tests/rl/test_dynamic_weight_gdr.py @@ -0,0 +1,574 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import importlib.util +import sys +import types +import unittest +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from unittest.mock import MagicMock, patch + +_DYNAMIC_WEIGHT_MODULE = None + + +def _install_dynamic_weight_manager_stubs(): + """Install minimal stubs so this unit test can run without Paddle installed.""" + + def no_grad(): + def decorator(func): + return func + + return decorator + + fake_paddle = types.SimpleNamespace( + Tensor=object, + no_grad=no_grad, + distributed=types.SimpleNamespace( + get_world_size=lambda: 1, + get_rank=lambda: 0, + barrier=lambda *args, **kwargs: None, + restart_process_group=lambda *args, **kwargs: None, + shutdown_process_group=lambda *args, **kwargs: None, + ), + device=types.SimpleNamespace( + cuda=types.SimpleNamespace( + synchronize=lambda: None, + empty_cache=lambda: None, + max_memory_allocated=lambda: 0, + max_memory_reserved=lambda: 0, + memory_allocated=lambda: 0, + memory_reserved=lambda: 0, + ) + ), + base=types.SimpleNamespace( + core=types.SimpleNamespace(LoDTensor=types.SimpleNamespace(_new_shared_cuda=MagicMock())) + ), + load=MagicMock(), + empty=MagicMock(), + to_tensor=MagicMock(), + ) + fake_logger = types.SimpleNamespace( + info=MagicMock(), + warning=MagicMock(), + error=MagicMock(), + debug=MagicMock(), + ) + fake_fastdeploy = types.ModuleType("fastdeploy") + fake_fastdeploy.__path__ = [] + fake_config = types.ModuleType("fastdeploy.config") + fake_config.FDConfig = object + fake_model_executor = types.ModuleType("fastdeploy.model_executor") + fake_model_executor.__path__ = [] + fake_model_executor_utils = types.ModuleType("fastdeploy.model_executor.utils") + fake_model_executor_utils.process_final_after_loading = MagicMock() + fake_numpy = types.ModuleType("numpy") + fake_envs = types.ModuleType("fastdeploy.envs") + fake_envs.FD_USE_GDR_CHECKPOINT_TRANSFER = False + fake_inter_communicator = types.ModuleType("fastdeploy.inter_communicator") + fake_inter_communicator.KVCacheStatus = types.SimpleNamespace() + fake_inter_communicator.ModelWeightsStatus = types.SimpleNamespace(NORMAL=0, CLEARED=1) + fake_yaml = types.ModuleType("yaml") + fake_yaml.safe_load = MagicMock(return_value={}) + fake_yaml.YAMLError = Exception + + sys.modules.update( + { + "paddle": fake_paddle, + "numpy": fake_numpy, + "yaml": fake_yaml, + "paddleformers": types.ModuleType("paddleformers"), + "paddleformers.utils": types.ModuleType("paddleformers.utils"), + "paddleformers.utils.log": types.SimpleNamespace(logger=fake_logger), + "fastdeploy": fake_fastdeploy, + "fastdeploy.envs": fake_envs, + "fastdeploy.config": fake_config, + "fastdeploy.model_executor": fake_model_executor, + "fastdeploy.model_executor.utils": fake_model_executor_utils, + "fastdeploy.inter_communicator": fake_inter_communicator, + } + ) + + +def _load_dynamic_weight_manager_from_file(): + module_path = Path(__file__).resolve().parents[2] / "fastdeploy" / "rl" / "dynamic_weight_manager.py" + spec = importlib.util.spec_from_file_location("dynamic_weight_manager_under_test", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _load_dynamic_weight_manager_module(): + global _DYNAMIC_WEIGHT_MODULE + if _DYNAMIC_WEIGHT_MODULE is not None: + return _DYNAMIC_WEIGHT_MODULE + + fastdeploy_module = sys.modules.get("fastdeploy") + if fastdeploy_module is not None and not hasattr(fastdeploy_module, "__path__"): + _DYNAMIC_WEIGHT_MODULE = _load_dynamic_weight_manager_from_file() + return _DYNAMIC_WEIGHT_MODULE + + try: + from fastdeploy.rl import dynamic_weight_manager + + _DYNAMIC_WEIGHT_MODULE = dynamic_weight_manager + return dynamic_weight_manager + except ModuleNotFoundError as exc: + if exc.name not in ("numpy", "paddle", "yaml"): + raise + + for name in list(sys.modules): + if name == "fastdeploy" or name.startswith("fastdeploy."): + sys.modules.pop(name, None) + _install_dynamic_weight_manager_stubs() + + _DYNAMIC_WEIGHT_MODULE = _load_dynamic_weight_manager_from_file() + return _DYNAMIC_WEIGHT_MODULE + + +class _FakeModel: + def __init__(self): + self.loaded = [] + self.params = {} + + def load_weights(self, weights_iterator): + self.loaded.extend(list(weights_iterator)) + + def state_dict(self): + return self.params + + +class _FakeMTPModel(_FakeModel): + def __init__(self, mtp_start_layer_idx=2, num_mtp_layers=1): + super().__init__() + self.mtp_start_layer_idx = mtp_start_layer_idx + self.num_mtp_layers = num_mtp_layers + + +def _make_manager(rsync_config=None, load_strategy="rsync"): + DynamicWeightManager = _load_dynamic_weight_manager_module().DynamicWeightManager + + manager = object.__new__(DynamicWeightManager) + fd_config = MagicMock() + fd_config.load_config.rsync_config = rsync_config or { + "backend": "mooncake", + "output_framework": "paddle", + } + fd_config.load_config.load_strategy = load_strategy + fd_config.parallel_config.data_parallel_rank = 2 + fd_config.parallel_config.data_parallel_size = 1 + fd_config.parallel_config.tensor_parallel_rank = 1 + fd_config.parallel_config.tensor_parallel_size = 4 + manager.fd_config = fd_config + manager.load_config = fd_config.load_config + manager.parallel_config = fd_config.parallel_config + manager.local_rank = 5 + manager.nranks = 8 + manager.rdma_handle = None + manager.model_list = [_FakeModel()] + manager.state_dict = {} + manager.use_gdr_checkpoint_transfer = True + return manager + + +class _FakeRole(Enum): + TRAINER = "trainer" + INFERENCE = "inference" + + +class _FakePhase1Backend(Enum): + GPU_DIRECT = "gpu_direct" + MOONCAKE = "mooncake" + IPC = "ipc" + + +@dataclass +class _FakeTransferConfig: + role: object + global_rank: int + group_size: int = 1 + phase1_backend: object = _FakePhase1Backend.GPU_DIRECT + phase2_backend: object = None + phase2_fan_out: int = 4 + bucket_size_mb: int = 512 + num_buffers: int = 2 + redis_host: str = "127.0.0.1" + redis_port: int = 6379 + discover_timeout_s: float = 60.0 + redis_ttl_s: int = 60 + recv_bucket_timeout_s: float = 60.0 + session_total_timeout_s: float = 600.0 + device: str = None + log_level: str = None + log_file: str = None + perf_log_file: str = None + materialize_tensors: bool = True + qsize: int = 3 + gpu_id: int = -1 + + def __post_init__(self): + self.kwargs = dict(self.__dict__) + self.kwargs.pop("kwargs", None) + + +def _patch_gdr_checkpoint_transfer(fake_checkpoint_transfer): + class FakeCheckpointTransferWithLifecycle(fake_checkpoint_transfer): + async def initialize(self): + self.initialized = True + + async def cleanup(self): + self.cleaned = True + + fake_config_module = types.SimpleNamespace( + Role=_FakeRole, + TransferConfig=_FakeTransferConfig, + Phase1Backend=_FakePhase1Backend, + ) + fake_transfer_module = types.SimpleNamespace(CheckpointTransfer=FakeCheckpointTransferWithLifecycle) + return patch.dict( + sys.modules, + { + "checkpoint_transfer.config": fake_config_module, + "checkpoint_transfer.transfer": fake_transfer_module, + }, + ) + + +class TestDynamicWeightGDR(unittest.TestCase): + def test_update_weights_by_gdr_gdr_mode(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + self.step_id = step_id + self.output_framework = output_framework + yield "model.layers.0.weight", object() + + manager = _make_manager() + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-1") + + self.assertEqual(result["version"], "step-1") + self.assertEqual(result["update_count"], 1) + self.assertIn("total_cost", result) + self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight") + self.assertTrue(created[0].initialized) + self.assertTrue(created[0].cleaned) + self.assertEqual(created[0].step_id, "step-1") + self.assertEqual(created[0].output_framework, "paddle") + self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE) + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT) + self.assertEqual(created[0].config.kwargs["global_rank"], 5) + self.assertEqual(created[0].config.kwargs["group_size"], 8) + self.assertNotIn("backend", created[0].config.kwargs) + self.assertNotIn("output_framework", created[0].config.kwargs) + + def test_update_weights_by_gdr_ipc_mode(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + self.step_id = step_id + yield "model.layers.0.weight", object() + + manager = _make_manager( + rsync_config={"redis_host": "10.0.0.1", "redis_port": 6379}, + load_strategy="ipc", + ) + + with ( + _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer), + patch.dict("os.environ", {"FLAGS_selected_gpus": "3"}), + ): + result = manager.update_weights_by_gdr() + + self.assertEqual(result["version"], "0") + self.assertEqual(created[0].step_id, "0") + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.IPC) + self.assertEqual(created[0].config.kwargs["global_rank"], 3) + self.assertEqual(created[0].config.kwargs["qsize"], 2) + + def test_gdr_checkpoint_transfer_receive_exception_propagates(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.weight", object() + raise RuntimeError("receive failed") + + class IncrementalModel(_FakeModel): + def load_weights(self, weights_iterator): + for item in weights_iterator: + self.loaded.append(item) + + manager = _make_manager() + manager.model_list = [IncrementalModel()] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + with self.assertRaisesRegex(RuntimeError, "receive failed"): + manager.update_weights_by_gdr(version="step-error") + + def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self): + loaded_param = object() + + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.weight", loaded_param + + class RefreshingModel(_FakeModel): + def load_weights(self, weights_iterator): + super().load_weights(weights_iterator) + self.params["model.weight"] = loaded_param + + manager = _make_manager() + manager.model_list = [RefreshingModel()] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-refresh") + + self.assertIs(manager.state_dict["model.weight"], loaded_param) + + def test_gdr_checkpoint_transfer_caches_mtp_subset_for_auxiliary_model(self): + objects = [object() for _ in range(4)] + + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.self_attn.q_proj.weight", objects[0] + yield "model.layers.2.self_attn.q_proj.weight", objects[1] + yield "model.layers.20.self_attn.q_proj.weight", objects[2] + yield "ernie.mtp_linear_proj.0.weight", objects[3] + + manager = _make_manager() + main_model = _FakeModel() + mtp_model = _FakeMTPModel(mtp_start_layer_idx=2, num_mtp_layers=1) + manager.model_list = [main_model, mtp_model] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-5") + + self.assertEqual(result["update_count"], 4) + self.assertEqual(result["mtp_cache_count"], 2) + self.assertEqual( + [name for name, _ in main_model.loaded], + [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.2.self_attn.q_proj.weight", + "model.layers.20.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ) + self.assertEqual( + [name for name, _ in mtp_model.loaded], + [ + "model.layers.2.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ) + + def test_gdr_checkpoint_transfer_flushes_mtp_subset_by_chunk_limit(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.2.self_attn.q_proj.weight", object() + yield "ernie.mtp_linear_proj.0.weight", object() + yield "model.layers.2.self_attn.o_proj.weight", object() + + class ChunkRecordingMTPModel(_FakeMTPModel): + def __init__(self): + super().__init__(mtp_start_layer_idx=2, num_mtp_layers=1) + self.load_calls = [] + + def load_weights(self, weights_iterator): + chunk = list(weights_iterator) + self.load_calls.append([name for name, _ in chunk]) + self.loaded.extend(chunk) + + manager = _make_manager( + { + "backend": "mooncake", + "output_framework": "paddle", + "gdr_mtp_chunk_size": 2, + } + ) + main_model = _FakeModel() + mtp_model = ChunkRecordingMTPModel() + manager.model_list = [main_model, mtp_model] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + result = manager.update_weights_by_gdr(version="step-8") + + self.assertEqual(result["mtp_cache_count"], 3) + self.assertEqual( + mtp_model.load_calls, + [ + [ + "model.layers.2.self_attn.q_proj.weight", + "ernie.mtp_linear_proj.0.weight", + ], + ["model.layers.2.self_attn.o_proj.weight"], + ], + ) + + def test_gdr_checkpoint_transfer_multi_model_requires_mtp_subset(self): + class FakeCheckpointTransfer: + def __init__(self, config): + pass + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "model.layers.0.self_attn.q_proj.weight", object() + + manager = _make_manager() + manager.model_list = [_FakeModel(), _FakeMTPModel(mtp_start_layer_idx=2, num_mtp_layers=1)] + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + with self.assertRaisesRegex(ValueError, "No MTP weights"): + manager.update_weights_by_gdr(version="step-5") + + def test_gdr_checkpoint_transfer_config_not_forwarded_to_transfer_config(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + manager = _make_manager( + { + "backend": "mooncake", + "output_framework": "paddle", + } + ) + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-6") + + self.assertNotIn("gpu_direct", created[0].config.kwargs) + self.assertNotIn("output_framework", created[0].config.kwargs) + self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT) + + def test_gdr_checkpoint_transfer_computes_global_rank_from_node_index(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + manager = _make_manager( + { + "index": 1, + "backend": "mooncake", + "output_framework": "paddle", + "group_size": 16, + } + ) + manager.local_rank = 5 + manager.nranks = 8 + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-index") + + self.assertEqual(created[0].config.kwargs["global_rank"], 13) + self.assertEqual(created[0].config.kwargs["group_size"], 16) + self.assertNotIn("index", created[0].config.kwargs) + + def test_gdr_checkpoint_transfer_config_deep_copied_before_forwarding(self): + created = [] + + class FakeCheckpointTransfer: + def __init__(self, config): + self.config = config + created.append(self) + + def receive_weights_sync(self, step_id, output_framework="paddle"): + yield "w1", object() + + rsync_config = { + "backend": "mooncake", + "output_framework": "paddle", + "device_name": "mlx5_0", + } + manager = _make_manager(rsync_config) + + with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): + manager.update_weights_by_gdr(version="step-7") + + self.assertEqual(created[0].config.kwargs["device"], "mlx5_0") + self.assertEqual(rsync_config["device_name"], "mlx5_0") + + def test_finalize_update_uses_worker_queue_port_status_suffix(self): + module = _load_dynamic_weight_manager_module() + manager = _make_manager() + manager.first_load = False + manager.rank = 0 + manager.parallel_config.tensor_parallel_size = 1 + manager.parallel_config.enable_expert_parallel = False + manager.parallel_config.local_engine_worker_queue_port = 60572 + manager._verify_parameters = MagicMock() + + class FakeArray: + shape = (1,) + dtype = "int32" + nbytes = 4 + + class FakeValue: + def __init__(self): + self.writes = {} + + def __setitem__(self, key, value): + self.writes[key] = value + + fake_value = FakeValue() + with ( + patch.object(module.np, "int32", "int32", create=True), + patch.object(module.np, "zeros", return_value=FakeArray(), create=True), + patch.object(module.np, "ndarray", return_value=fake_value, create=True), + patch.object(module, "SharedMemory") as fake_shared_memory, + ): + manager.finalize_update() + + fake_shared_memory.assert_called_once_with(create=False, size=4, name="model_weights_status.60572") + self.assertEqual(fake_value.writes[0], module.ModelWeightsStatus.NORMAL) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/router/test_router.py b/tests/router/test_router.py index aa5be52f2f1..3970cd8b46d 100644 --- a/tests/router/test_router.py +++ b/tests/router/test_router.py @@ -20,6 +20,7 @@ We mock it at the network boundary to test Router's registration and selection logic. """ +import asyncio import unittest from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch @@ -28,7 +29,14 @@ def _make_args(**kwargs): - defaults = {"host": "0.0.0.0", "port": 9000, "splitwise": False, "request_timeout_secs": 30} + defaults = { + "host": "0.0.0.0", + "port": 9000, + "splitwise": False, + "request_timeout_secs": 30, + "preempt_retry_count": 3, + "preempt_retry_exclude_decode": False, + } defaults.update(kwargs) return SimpleNamespace(**defaults) @@ -171,7 +179,7 @@ async def _coro(): @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health): - """P and D both receive the request, but only D results are aggregated.""" + """Router returns 200 immediately and forwards to all (P + D) servers in background.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -182,24 +190,8 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 0}], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 200 - decode_resp.json = AsyncMock( - return_value={ - "request_id": "control-d", - "status": "success", - "error_message": None, - "result": {"aborted": [{"request_id": "req-1_0", "output_token_count": 15}], "not_found": []}, - } - ) mock_session = self._make_mock_session([prefill_resp, decode_resp]) mock_request = AsyncMock() @@ -207,18 +199,17 @@ async def test_abort_broadcasts_to_all_but_returns_decode_only(self, mock_health with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + # Give the background task a chance to run + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(len(body["result"]["aborted"]), 1) - self.assertEqual(body["result"]["aborted"][0]["output_token_count"], 15) - self.assertEqual(body["status"], "success") + self.assertEqual(resp.status_code, 200) + # Forwarded to both prefill + decode self.assertEqual(mock_session.post.call_count, 2) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_error_returns_error_status(self, mock_health): - """When D node returns a non-200 status, status should be 'error'.""" + async def test_abort_returns_200_even_when_decode_errors(self, mock_health): + """Router fire-and-forgets: still returns 200 when D returns non-200.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -229,14 +220,6 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) decode_resp = AsyncMock() decode_resp.status = 500 @@ -246,16 +229,14 @@ async def test_abort_decode_error_returns_error_status(self, mock_health): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIsNotNone(body["error_message"]) + self.assertEqual(resp.status_code, 200) @patch("fastdeploy.router.router.check_service_health_async", new_callable=AsyncMock, return_value=True) - async def test_abort_decode_exception_returns_error(self, mock_health): - """When D node connection fails (exception), error should be captured.""" + async def test_abort_returns_200_when_decode_raises(self, mock_health): + """Router fire-and-forgets: still returns 200 when a downstream raises.""" from fastdeploy.router.router import abort_requests as abort_fn from fastdeploy.router.router import app @@ -266,30 +247,20 @@ async def test_abort_decode_exception_returns_error(self, mock_health): prefill_resp = AsyncMock() prefill_resp.status = 200 - prefill_resp.json = AsyncMock( - return_value={ - "request_id": "control-p", - "status": "success", - "error_message": None, - "result": {"aborted": [], "not_found": []}, - } - ) - - # D node raises exception — but asyncio.gather(return_exceptions=True) captures it - # So we pass the exception as a response directly + mock_session = self._make_mock_session([prefill_resp, prefill_resp]) # placeholder call_idx = [0] def post_with_exception(*args, **kwargs): call_idx[0] += 1 if call_idx[0] == 1: - # prefill: normal + async def _coro(): return prefill_resp return _coro() else: - # decode: raise (gather with return_exceptions=True will catch) + async def _coro_err(): raise ConnectionError("refused") @@ -301,12 +272,10 @@ async def _coro_err(): with patch("fastdeploy.router.router.aiohttp.ClientSession", return_value=mock_session): resp = await abort_fn(mock_request) + await asyncio.sleep(0) + await asyncio.sleep(0) - import json - - body = json.loads(resp.body) - self.assertEqual(body["status"], "error") - self.assertIn("refused", body["error_message"]) + self.assertEqual(resp.status_code, 200) if __name__ == "__main__": diff --git a/tests/scheduler/test_chunked_prefill_determinism.py b/tests/scheduler/test_chunked_prefill_determinism.py index 1a0f786f3d1..17b466a014c 100644 --- a/tests/scheduler/test_chunked_prefill_determinism.py +++ b/tests/scheduler/test_chunked_prefill_determinism.py @@ -78,6 +78,7 @@ def __init__(self): self.cache_config = CacheConfig() self.parallel_config = ParallelConfig() self.speculative_config = SpeculativeConfig() + self.enable_mm_runtime = self.model_config.enable_mm # --------------------------------------------------------------------------- @@ -168,6 +169,7 @@ def _create_resource_manager(self, config): def _create_mm_resource_manager(self): config = StubConfig() config.model_config.enable_mm = True + config.enable_mm_runtime = config.model_config.enable_mm return self._create_resource_manager(config) # ==================== 1. Deterministic disabled ==================== diff --git a/tests/scheduler/test_dp_scheduler.py b/tests/scheduler/test_dp_scheduler.py index 0e42c4491f3..a5f9cfa8380 100644 --- a/tests/scheduler/test_dp_scheduler.py +++ b/tests/scheduler/test_dp_scheduler.py @@ -411,6 +411,32 @@ def test_recycle_expired_requests(self, mock_time): self.assertEqual(scheduler.ids, ["fresh_req"]) self.assertEqual(scheduler.ids_read_cursor, 1) + def test_get_requests_insufficient_resources(self): + """Test getting requests when resources are insufficient.""" + mock_logger.reset_mock() + + # Test with insufficient blocks - mock the condition variable to avoid threading issues + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=5, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=1 + ) + + self.assertEqual(requests, []) + # The logger should have been called for insufficient resources + self.assertTrue(mock_logger.debug.called) + # Check the message contains expected content + call_args = mock_logger.debug.call_args[0][0] + self.assertIn("insufficient", call_args.lower()) + + def test_get_requests_insufficient_batch(self): + """Test getting requests when batch size is insufficient.""" + with patch.object(self.scheduler, "requests_not_empty"): + requests = self.scheduler.get_requests( + available_blocks=20, block_size=16, reserved_output_blocks=10, max_num_batched_tokens=1024, batch=0 + ) + + self.assertEqual(requests, []) + @patch("time.time") @patch.object(dp_scheduler_module, "envs") def test_get_requests_no_requests_available(self, mock_envs, mock_time): diff --git a/tests/splitwise/test_internal_adapter_utils.py b/tests/splitwise/test_internal_adapter_utils.py index f8f22215c02..4d772789848 100644 --- a/tests/splitwise/test_internal_adapter_utils.py +++ b/tests/splitwise/test_internal_adapter_utils.py @@ -25,9 +25,6 @@ class DummyEngine: """Dummy Engine class to simulate the actual Engine for testing.""" class ResourceManager: - def __init__(self): - self.waiting = [] - def available_batch(self): return 4 diff --git a/tests/splitwise/test_splitwise_connector.py b/tests/splitwise/test_splitwise_connector.py index 610cfd9246d..cc39a52cd76 100644 --- a/tests/splitwise/test_splitwise_connector.py +++ b/tests/splitwise/test_splitwise_connector.py @@ -17,6 +17,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from threading import Lock from typing import Any, Dict, List from unittest.mock import Mock, patch @@ -24,13 +25,8 @@ import pytest import zmq -if not hasattr(paddle, "compat"): - - class _CompatStub: - def enable_torch_proxy(self, scope=None): - return None - - paddle.compat = _CompatStub() +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None from fastdeploy import envs from fastdeploy.engine.request import Request, RequestMetrics, RequestOutput @@ -89,6 +85,10 @@ def _build_connector() -> SplitwiseConnector: connector = SplitwiseConnector(cfg=DummyCfg(), worker_queue=DummyWorkerQueue(), resource_manager=None) if not hasattr(connector, "push_sockets"): connector.push_sockets = {} + if not hasattr(connector, "_push_socket_locks"): + connector._push_socket_locks = {} + if not hasattr(connector, "_push_sockets_meta_lock"): + connector._push_sockets_meta_lock = Lock() return connector @@ -208,15 +208,6 @@ def test_send_cache_info_to_prefill_groups_by_addr_and_skips_error(): "block_tables": [1, 2, 3], }, ), - DummyTask( - request_id="req-err", - disaggregate_info={ - "prefill_ip": "10.0.0.2", - "prefill_connector_port": 9002, - "block_tables": [9], - }, - error_msg="failed", - ), ] connector.send_cache_info_to_prefill(tasks) @@ -262,9 +253,11 @@ def test_get_push_socket_reuses_existing_and_handles_zmq_error(): open_socket = Mock() open_socket.closed = False connector.push_sockets["127.0.0.1:8000"] = open_socket + connector._push_socket_locks["127.0.0.1:8000"] = Lock() - same_socket = connector._get_push_socket("127.0.0.1:8000") + same_socket, same_lock = connector._get_push_socket("127.0.0.1:8000") assert same_socket is open_socket + assert same_lock is connector._push_socket_locks["127.0.0.1:8000"] connector.zmq_ctx = Mock() connector.zmq_ctx.socket.side_effect = zmq.ZMQError("boom") @@ -279,9 +272,10 @@ def test_get_push_socket_creates_and_configures_socket(): new_socket.closed = False connector.zmq_ctx.socket.return_value = new_socket - socket = connector._get_push_socket("127.0.0.1:7000") + socket, lock = connector._get_push_socket("127.0.0.1:7000") assert socket is new_socket + assert lock is connector._push_socket_locks["127.0.0.1:7000"] new_socket.connect.assert_called_once_with("tcp://127.0.0.1:7000") assert connector.push_sockets["127.0.0.1:7000"] is new_socket @@ -289,7 +283,8 @@ def test_get_push_socket_creates_and_configures_socket(): def test_send_message_serializes_and_sends_payload(): connector = _build_connector() mock_socket = Mock() - connector._get_push_socket = Mock(return_value=mock_socket) + mock_socket.closed = False + connector._get_push_socket = Mock(return_value=(mock_socket, Lock())) request = Request( request_id="req-send", prompt=None, @@ -326,14 +321,17 @@ def test_send_message_handles_missing_addr_and_errors(): connector._send_message("127.0.0.1:7000", "prefill", []) failing_socket = Mock() + failing_socket.closed = False failing_socket.send_multipart.side_effect = zmq.Again() - connector._get_push_socket = Mock(return_value=failing_socket) + connector._get_push_socket = Mock(return_value=(failing_socket, Lock())) connector._send_message("127.0.0.1:7001", "prefill", []) crash_socket = Mock() + crash_socket.closed = False crash_socket.send_multipart.side_effect = RuntimeError("boom") - connector._get_push_socket = Mock(return_value=crash_socket) + connector._get_push_socket = Mock(return_value=(crash_socket, Lock())) connector.push_sockets["127.0.0.1:7002"] = crash_socket + connector._push_socket_locks["127.0.0.1:7002"] = Lock() connector._send_message("127.0.0.1:7002", "prefill", []) assert "127.0.0.1:7002" not in connector.push_sockets diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 240cf702ed7..4f55ca46472 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -138,7 +138,7 @@ def test_fdconfig_init_cache(self): model_config=model_config, test_mode=True, ) - fd_config.init_cache_info() + fd_config.init_pd_info() assert fd_config.register_info is not None def test_fdconfig_postprocess_ports(self): diff --git a/tests/utils/test_find_free_ports.py b/tests/utils/test_find_free_ports.py new file mode 100644 index 00000000000..3ffe272443e --- /dev/null +++ b/tests/utils/test_find_free_ports.py @@ -0,0 +1,212 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from unittest.mock import patch + +import pytest + +from fastdeploy.utils import find_free_ports + + +class TestFindFreePorts: + """Unit tests for find_free_ports function.""" + + def test_find_single_free_port_success(self): + """Test finding a single free port successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=1) + assert len(ports) == 1 + assert 20000 <= ports[0] <= 20100 + + def test_find_multiple_free_ports_success(self): + """Test finding multiple free ports successfully.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(20000, 20100), num_ports=5) + assert len(ports) == 5 + for port in ports: + assert 20000 <= port <= 20100 + + def test_find_ports_with_custom_host(self): + """Test finding ports with a custom host.""" + with patch("fastdeploy.utils.is_port_available", return_value=True) as mock_avail: + ports = find_free_ports(port_range=(30000, 30010), num_ports=2, host="127.0.0.1") + assert len(ports) == 2 + # Verify is_port_available was called with the custom host + for call in mock_avail.call_args_list: + assert call[0][0] == "127.0.0.1" + + def test_find_all_ports_in_range(self): + """Test finding all ports in a small range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(40000, 40002), num_ports=3) + assert len(ports) == 3 + # All ports should be from the range + expected_ports = {40000, 40001, 40002} + assert set(ports) == expected_ports + + def test_invalid_port_range_start_negative(self): + """Test ValueError when port range start is negative.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(-1, 1000)) + + def test_invalid_port_range_end_exceeds_max(self): + """Test ValueError when port range end exceeds 65535.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(1000, 65536)) + + def test_invalid_port_range_start_greater_than_end(self): + """Test ValueError when port range start is greater than end.""" + with pytest.raises(ValueError, match="Invalid port range"): + find_free_ports(port_range=(10000, 9000)) + + def test_invalid_port_range_boundary_values(self): + """Test port range boundary at exactly 0 and 65535.""" + # Valid: start = 0 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(0, 100), num_ports=1) + assert len(ports) == 1 + + # Valid: end = 65535 + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(65530, 65535), num_ports=1) + assert len(ports) == 1 + + def test_num_ports_zero_raises_error(self): + """Test ValueError when num_ports is zero.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + find_free_ports(port_range=(20000, 30000), num_ports=0) + + def test_num_ports_negative_raises_error(self): + """Test ValueError when num_ports is negative.""" + with pytest.raises(ValueError, match="num_ports must be a positive integer"): + find_free_ports(port_range=(20000, 30000), num_ports=-1) + + def test_num_ports_larger_than_range_size(self): + """Test ValueError when num_ports exceeds the range size.""" + # Range has only 5 ports (100-104), but requesting 6 + with pytest.raises(ValueError, match="num_ports is larger than range size"): + find_free_ports(port_range=(100, 104), num_ports=6) + + def test_not_enough_free_ports_raises_runtime_error(self): + """Test RuntimeError when not enough free ports are available.""" + # Mock to return False for all ports + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(20000, 20010), num_ports=3) + + def test_partial_free_ports_raises_runtime_error(self): + """Test RuntimeError when only some ports are free.""" + call_count = [0] + + def mock_availability(host, port): + # Only first 2 ports are available + call_count[0] += 1 + return call_count[0] <= 2 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with pytest.raises(RuntimeError, match="Only found 2 free ports"): + find_free_ports(port_range=(20000, 20005), num_ports=5) + + def test_random_start_offset(self): + """Test that port scanning starts from a random offset.""" + # Track the order of ports checked + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 0, ports should be checked in order + assert checked_ports[:3] == [100, 101, 102] + assert ports == [100, 101, 102] + + def test_random_start_offset_non_zero(self): + """Test port scanning with non-zero random offset.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + # With offset 2, scanning starts from port 102 + with patch("fastdeploy.utils.random.randint", return_value=2): + ports = find_free_ports(port_range=(100, 105), num_ports=3) + + # With offset 2, ports are rotated: [102, 103, 104, 105, 100, 101] + assert checked_ports[:3] == [102, 103, 104] + assert ports == [102, 103, 104] + + def test_single_port_range(self): + """Test finding port from a single-port range.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports(port_range=(12345, 12345), num_ports=1) + assert ports == [12345] + + def test_single_port_range_not_available(self): + """Test RuntimeError when the single port in range is not available.""" + with patch("fastdeploy.utils.is_port_available", return_value=False): + with pytest.raises(RuntimeError, match="Only found 0 free ports"): + find_free_ports(port_range=(12345, 12345), num_ports=1) + + def test_default_parameters(self): + """Test function with default parameters.""" + with patch("fastdeploy.utils.is_port_available", return_value=True): + ports = find_free_ports() + assert len(ports) == 1 + assert 8000 <= ports[0] <= 65535 + + def test_stops_early_when_enough_ports_found(self): + """Test that scanning stops as soon as enough ports are found.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + return True + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + # Range has 100 ports but we only need 2 + ports = find_free_ports(port_range=(20000, 20099), num_ports=2) + + # Should only check 2 ports, not all 100 + assert len(checked_ports) == 2 + assert len(ports) == 2 + + def test_skips_unavailable_ports(self): + """Test that unavailable ports are skipped.""" + checked_ports = [] + + def mock_availability(host, port): + checked_ports.append(port) + # Only odd ports are available + return port % 2 == 1 + + with patch("fastdeploy.utils.is_port_available", side_effect=mock_availability): + with patch("fastdeploy.utils.random.randint", return_value=0): + ports = find_free_ports(port_range=(100, 110), num_ports=3) + + # Should find 3 odd ports: 101, 103, 105 + assert len(ports) == 3 + assert all(p % 2 == 1 for p in ports) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/v1/cache_manager/test_prefix_cache.py b/tests/v1/cache_manager/test_prefix_cache.py index b3393500173..0a5eb669582 100644 --- a/tests/v1/cache_manager/test_prefix_cache.py +++ b/tests/v1/cache_manager/test_prefix_cache.py @@ -31,7 +31,6 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=enable_mm, max_model_len=4196) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.architectures = ["test_model"] model_cfg.mm_max_tokens_per_item = None @@ -46,7 +45,7 @@ def make_prefix_cache_manager(max_num_seqs, enable_mm=False, num_gpu_blocks_over cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) return PrefixCacheManager(config=fd_config, tensor_parallel_size=8, splitwise_role="mixed") diff --git a/tests/v1/test_resource_manager_v1.py b/tests/v1/test_resource_manager_v1.py index 7cee36cd060..6c51adb63a5 100644 --- a/tests/v1/test_resource_manager_v1.py +++ b/tests/v1/test_resource_manager_v1.py @@ -24,9 +24,10 @@ import numpy as np import paddle -if not hasattr(paddle, "compat"): - paddle.compat = SimpleNamespace(enable_torch_proxy=lambda scope: None) +if not hasattr(paddle, "enable_compat"): + paddle.enable_compat = lambda scope=None: None +from fastdeploy import envs from fastdeploy.config import CacheConfig, FDConfig, ParallelConfig, SchedulerConfig from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.engine.request import ( @@ -36,6 +37,7 @@ RequestMetrics, RequestOutput, RequestStatus, + RequestType, ) from fastdeploy.engine.sched.resource_manager_v1 import ( ResourceManagerV1, @@ -138,7 +140,6 @@ def setUp(self): cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=True) # Enable multimodal for feature testing - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 3200 model_cfg.architectures = ["test_model"] @@ -155,7 +156,7 @@ def setUp(self): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) self.manager = ResourceManagerV1( @@ -570,6 +571,26 @@ def test_preallocate_resource_in_p_and_d(self): self.assertEqual(request_d.num_computed_tokens, request_d.need_prefill_tokens) self.assertEqual(request_d.disaggregate_info["block_tables"], [4, 5]) + def test_decode_role_prefill_task_logs_decode_bootstrap_batch(self): + manager = _build_manager(splitwise_role="decode", enable_prefix_caching=False) + _register_manager_cleanup(self, manager) + manager.cache_manager = MagicMock() + manager.cache_manager.num_gpu_blocks = 8 + manager.cache_manager.gpu_free_block_list = [0, 1, 2, 3] + manager.scheduler_metrics_logger = MagicMock() + + request = _make_request(prompt_token_ids=[1, 2, 3, 4]) + request.task_type = RequestType.PREFILL + request.prefill_start_index = 4 + request.prefill_end_index = 5 + batch_request = [request] + + with patch.object(envs, "FD_CONSOLE_SCHEDULER_METRICS", True): + manager._log_console_scheduler_metrics(batch_request) + + manager.scheduler_metrics_logger.log_decode_bootstrap_batch.assert_called_once() + manager.scheduler_metrics_logger.log_prefill_batch.assert_not_called() + def test_prefilled_request_flow_and_resource_check(self): manager = _build_manager(splitwise_role="decode", speculative_method="mtp") _register_manager_cleanup(self, manager) @@ -580,6 +601,7 @@ def test_prefilled_request_flow_and_resource_check(self): self.assertTrue(manager.has_resource_for_prefilled_req("prefilled")) request = _make_request(request_id="req-prefilled") + request.idx = 0 request.metrics.decode_recv_req_time = 1.0 request.metrics.decode_preallocate_req_time = 2.0 manager.requests[request.request_id] = request @@ -650,7 +672,7 @@ def test_schedule_decode_and_waiting_prefill(self): decode_request = _make_request(request_id="req-decode", prompt_token_ids=[1, 2]) decode_request.idx = 0 - decode_request.status = RequestStatus.RUNNING + decode_request.status = RequestStatus.RUNNING_DECODE decode_request.num_computed_tokens = 2 decode_request.output_token_ids = [99] decode_request.block_tables = [1] @@ -665,7 +687,7 @@ def test_schedule_decode_and_waiting_prefill(self): self.assertGreaterEqual(len(scheduled_reqs), 2) self.assertEqual(error_reqs, []) self.assertIn(decode_request.request_id, manager.using_extend_tables_req_id) - self.assertEqual(waiting_request.status, RequestStatus.RUNNING) + self.assertEqual(waiting_request.status, RequestStatus.RUNNING_PREFILL) def test_trigger_preempt_records_tasks(self): manager = _build_manager() @@ -678,6 +700,7 @@ def test_trigger_preempt_records_tasks(self): preempted_req = _make_request(request_id="req-preempted") preempted_req.idx = 0 preempted_req.use_extend_tables = False + preempted_req.status = RequestStatus.RUNNING_DECODE request = _make_request(request_id="req-target") request.idx = 1 manager.running = [request, preempted_req] diff --git a/tests/v1/test_schedule_output.py b/tests/v1/test_schedule_output.py index 3175b087e2e..db30b15f48b 100644 --- a/tests/v1/test_schedule_output.py +++ b/tests/v1/test_schedule_output.py @@ -29,7 +29,6 @@ def test_normal_schedule(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -41,7 +40,7 @@ def test_normal_schedule(): model_config=model_cfg, cache_config=cache_cfg, parallel_config=parallel_cfg, - speculative_config=speculative_cfg, + speculative_config=None, graph_opt_config=graph_opt_cfg, scheduler_config=scheduler_cfg, ) @@ -95,7 +94,6 @@ def test_preempted_request(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -108,7 +106,7 @@ def test_preempted_request(): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) resource_manager_v1 = ResourceManagerV1( @@ -162,7 +160,6 @@ def test_caching_output(): args = asdict(engine_args) cache_cfg = CacheConfig(args) model_cfg = SimpleNamespace(enable_mm=False) - speculative_cfg = SimpleNamespace(method=None) model_cfg.print = print model_cfg.max_model_len = 5120 model_cfg.mm_max_tokens_per_item = None @@ -175,7 +172,7 @@ def test_caching_output(): cache_config=cache_cfg, parallel_config=parallel_cfg, graph_opt_config=graph_opt_cfg, - speculative_config=speculative_cfg, + speculative_config=None, scheduler_config=scheduler_cfg, ) resource_manager_v1 = ResourceManagerV1( diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py index 3a02475b5ae..008eafcaf62 100644 --- a/tests/worker/test_gpu_model_runner.py +++ b/tests/worker/test_gpu_model_runner.py @@ -14,11 +14,12 @@ import unittest from dataclasses import dataclass -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import paddle +from fastdeploy.config import PREEMPTED_TOKEN_ID from fastdeploy.engine.request import ImagePosition from fastdeploy.spec_decode import SpecMethod from fastdeploy.worker.gpu_model_runner import GPUModelRunner @@ -487,7 +488,7 @@ def _make_runner(self): runner.local_rank = 0 runner.device_id = 1 runner.num_gpu_blocks = 8 - runner.model = Mock(clear_grpah_opt_backend=Mock()) + runner.model = Mock(clear_graph_opt_backend=Mock()) runner.clear_cache = Mock() runner.initialize_kv_cache = Mock() runner.capture_model = Mock() @@ -523,7 +524,7 @@ def test_sleep_offloads_weight_and_cache(self, mock_empty_cache, mock_print_memo runner.sleep("weight,kv_cache") - runner.model.clear_grpah_opt_backend.assert_called_once() + runner.model.clear_graph_opt_backend.assert_called_once() runner.dynamic_weight_manager.clear_deepep_buffer.assert_called_once() runner.dynamic_weight_manager.clear_model_weight.assert_called_once() runner.dynamic_weight_manager.clear_communication_group.assert_called_once() @@ -591,5 +592,343 @@ def test_wakeup_kvcache_is_idempotent(self, mock_print_memory): mock_print_memory.assert_not_called() +class TestMakePreemptedBatchOutput(unittest.TestCase): + def _make_runner(self, speculative_decoding=False, enable_logprob=False): + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.speculative_decoding = speculative_decoding + runner.enable_logprob = enable_logprob + runner.parallel_config = Mock(msg_queue_id=0, tensor_parallel_rank=0, use_ep=False) + + class _ShareInputs(dict): + enable_pd_reorder = False + + share_inputs = _ShareInputs() + share_inputs["preempted_idx"] = paddle.to_tensor( + [[0], [0], [0], [1], [0], [0], [1], [0], [0], [0]], dtype="int32" + ) + share_inputs["sampled_token_ids"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["index_to_batch_id"] = {i: i for i in range(10)} + share_inputs["next_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_flags"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["step_idx"] = 0 + share_inputs["max_dec_len"] = 16 + share_inputs["seq_lens_this_time"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["eos_token_id"] = paddle.zeros([1], dtype="int64") + share_inputs["not_need_stop"] = False + share_inputs["not_need_stop_device"] = paddle.zeros([1], dtype="bool") + share_inputs["input_ids"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["seq_lens_encoder"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["seq_lens_decoder"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["is_block_step"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["token_ids_all"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_seqs"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["stop_seqs_len"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["min_dec_len"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["prompt_lens"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["mask_rollback"] = paddle.zeros([10, 1], dtype="bool") + share_inputs["accept_tokens_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int64") + share_inputs["accept_num_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["seq_lens_decoder_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["prompt_lens_cpu"] = paddle.full([10, 1], fill_value=-1, dtype="int32") + share_inputs["draft_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["actual_draft_token_num"] = paddle.zeros([10, 1], dtype="int32") + share_inputs["accept_tokens"] = paddle.zeros([10, 1], dtype="int64") + share_inputs["accept_num"] = paddle.zeros([10, 1], dtype="int32") + runner.share_inputs = share_inputs + return runner + + def test_make_preempted_batch_output_emits_sparse_preempt_mask(self): + runner = self._make_runner() + + model_output_data, sampler_output = runner._make_preempted_batch_output() + + expected = [-1, -1, -1, PREEMPTED_TOKEN_ID, -1, -1, PREEMPTED_TOKEN_ID] + self.assertEqual(sampler_output.sampled_token_ids.shape, [7, 1]) + self.assertEqual(sampler_output.sampled_token_ids.numpy().reshape([-1]).tolist(), expected) + self.assertEqual(runner.share_inputs["sampled_token_ids"][:7].numpy().reshape([-1]).tolist(), expected) + self.assertEqual(model_output_data.index_to_batch_id, {i: i for i in range(7)}) + + def test_make_preempted_batch_output_speculative_logprob(self): + runner = self._make_runner(speculative_decoding=True, enable_logprob=True) + runner.share_inputs["seq_lens_decoder"][:7] = paddle.arange(7, dtype="int32").reshape([7, 1]) + runner.share_inputs["prompt_lens"][:7] = paddle.arange(10, 17, dtype="int32").reshape([7, 1]) + + model_output_data, sampler_output = runner._make_preempted_batch_output() + + self.assertEqual(sampler_output.sampled_token_ids.shape, [7, 1]) + self.assertIsNotNone(sampler_output.logprobs_tensors) + self.assertEqual(sampler_output.logprobs_tensors.logprob_token_ids.shape, [7, 1]) + self.assertEqual(sampler_output.token_num_per_batch.shape, [7, 1]) + self.assertEqual(sampler_output.cu_batch_token_offset.shape, [8]) + self.assertEqual(runner.share_inputs["accept_tokens_cpu"][:7].numpy().reshape([-1]).tolist(), [0] * 7) + self.assertEqual(runner.share_inputs["accept_num_cpu"][:7].numpy().reshape([-1]).tolist(), [0] * 7) + self.assertEqual( + runner.share_inputs["seq_lens_decoder_cpu"][:7].numpy().reshape([-1]).tolist(), + list(range(7)), + ) + self.assertEqual( + runner.share_inputs["prompt_lens_cpu"][:7].numpy().reshape([-1]).tolist(), + list(range(10, 17)), + ) + self.assertIsNotNone(model_output_data.accept_tokens) + self.assertIsNotNone(model_output_data.accept_num) + + +class TestExecuteModel(unittest.TestCase): + def _make_runner(self): + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.speculative_decoding = False + runner.parallel_config = Mock(use_ep=False) + runner.fd_config = Mock() + runner.fd_config.speculative_config = Mock(method=None) + runner.proposer = Mock(model=Mock()) + runner.forward_meta = Mock() + runner._save_model_output = Mock() + runner._make_preempted_batch_output = Mock(return_value=("model_output", "sampler_output")) + runner._postprocess = Mock() + runner._execute_empty_mtp_input = Mock() + runner._cached_launch_token_num = 0 + runner._cached_real_bsz = 0 + runner.routing_replay_manager = Mock() + + class _ShareInputs(dict): + pass + + share_inputs = _ShareInputs() + share_inputs["seq_lens_this_time_cpu"] = paddle.zeros([2, 1], dtype="int32") + share_inputs["preempted_idx"] = paddle.to_tensor([[1], [0]], dtype="int32") + share_inputs["last_preempted_idx"] = paddle.zeros([2, 1], dtype="int32") + runner.share_inputs = share_inputs + return runner + + def test_execute_model_dispatches_to_normal_path(self): + runner = self._make_runner() + runner.enable_overlap_schedule = False + runner.execute_model_normal = Mock() + runner.execute_model_overlap = Mock() + + runner.execute_model(model_forward_batch=["req"], num_running_requests=1) + + runner.execute_model_normal.assert_called_once_with(["req"], 1) + runner.execute_model_overlap.assert_not_called() + + def test_execute_model_dispatches_to_overlap_path(self): + runner = self._make_runner() + runner.enable_overlap_schedule = True + runner.execute_model_normal = Mock() + runner.execute_model_overlap = Mock() + + runner.execute_model(model_forward_batch=["req"], num_running_requests=1) + + runner.execute_model_overlap.assert_called_once_with(["req"], 1) + runner.execute_model_normal.assert_not_called() + + def test_execute_model_normal_zero_output_flushes_preempted_batch(self): + runner = self._make_runner() + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", None)) + runner._execute = Mock(return_value=None) + + runner.execute_model_normal() + + runner._make_preempted_batch_output.assert_called_once_with() + np.testing.assert_array_equal(runner.share_inputs["last_preempted_idx"].numpy(), np.array([[1], [0]])) + np.testing.assert_array_equal(runner.share_inputs["preempted_idx"].numpy(), np.array([[0], [0]])) + runner._save_model_output.assert_called_once_with("model_output", "sampler_output") + + def test_execute_model_normal_postprocess_saves_output_after_sync(self): + runner = self._make_runner() + runner.share_inputs["seq_lens_this_time_cpu"] = paddle.to_tensor([[1], [0]], dtype="int32") + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", None)) + runner._execute = Mock(return_value="model_output") + post_process_event = Mock() + runner._postprocess.return_value = ("model_output_data", "sampler_output", post_process_event) + + runner.execute_model_normal(model_forward_batch=["req"], num_running_requests=1) + + runner._make_preempted_batch_output.assert_not_called() + post_process_event.synchronize.assert_called_once_with() + runner._save_model_output.assert_called_once_with("model_output_data", "sampler_output") + + def test_execute_model_overlap_zero_output_flushes_preempted_batch(self): + runner = self._make_runner() + token_num_event = Mock() + runner._preprocess = Mock(return_value=("model_inputs", "done_idxs", token_num_event)) + runner._execute = Mock(return_value=None) + runner._predict_next_launch_token_num = Mock(return_value=(11, 22)) + runner._cached_model_output_data = None + runner._cached_sampler_output = "cached_sampler" + runner._cached_post_process_event = "cached_event" + + runner.execute_model_overlap() + + token_num_event.synchronize.assert_called_once_with() + runner._make_preempted_batch_output.assert_called_once_with() + np.testing.assert_array_equal(runner.share_inputs["last_preempted_idx"].numpy(), np.array([[1], [0]])) + np.testing.assert_array_equal(runner.share_inputs["preempted_idx"].numpy(), np.array([[0], [0]])) + runner._save_model_output.assert_called_once_with("model_output", "sampler_output") + self.assertIsNone(runner._cached_model_output_data) + self.assertIsNone(runner._cached_sampler_output) + self.assertIsNone(runner._cached_post_process_event) + self.assertEqual(runner._cached_launch_token_num, 11) + self.assertEqual(runner._cached_real_bsz, 22) + + +def _sync_async_set_value(tgt, src): + """Synchronous stand-in for async_set_value used in tests (no CUDA required). + + Writes to real numpy arrays; silently skips Mock objects (untracked share_inputs + fields whose values we do not assert on). + """ + from unittest.mock import MagicMock + + import numpy as np + + if isinstance(tgt, MagicMock): + return # untracked field — nothing to write + if isinstance(src, (int, float, bool)): + tgt[:] = src + elif isinstance(src, (list, np.ndarray)): + tgt[:] = np.array(src).reshape(tgt.shape) + elif hasattr(src, "numpy"): + tgt[:] = src.numpy() + else: + tgt[:] = src + + +class TestInsertTasksV1SplitwiseSuffix(unittest.TestCase): + """Tests for insert_tasks_v1 splitwise_role=\'decode\' + SpecMethod.SUFFIX branch.""" + + def _make_share_inputs(self, bsz=4, max_draft=6): + """Mock-backed share_inputs; only keys we assert on hold real numpy arrays.""" + import numpy as np + + # Keys whose values we want to inspect after the call + tracked = { + "seq_lens_encoder": np.zeros((bsz, 1), dtype=np.int32), + "draft_tokens": np.zeros((bsz, max_draft), dtype=np.int64), + "seq_lens_this_time_buffer": np.zeros((bsz, 1), dtype=np.int32), + "req_ids": [""] * bsz, + "preempted_idx": np.zeros((bsz, 1), dtype=np.int32), + "num_running_requests": 0, + "running_requests_ids": [], + } + + class _SI: + def get_index_by_batch_id(self, batch_id): + return batch_id + + def __getitem__(self, key): + # Return real array for tracked keys; Mock for everything else + if key in tracked: + return tracked[key] + return MagicMock() + + def __setitem__(self, key, value): + tracked[key] = value + + return _SI() + + def _make_runner(self, bsz=4, num_spec_tokens=3): + from unittest.mock import Mock + + from fastdeploy.spec_decode import SpecMethod + from fastdeploy.worker.gpu_model_runner import GPUModelRunner + + runner = GPUModelRunner.__new__(GPUModelRunner) + runner.enable_mm = False + runner.is_pooling_model = False + runner.speculative_decoding = True + runner.spec_method = SpecMethod.SUFFIX + runner.speculative_config = Mock(num_speculative_tokens=num_spec_tokens) + runner.deterministic_logger = None + runner.routing_replay_manager = Mock() + runner.prompt_logprobs_reqs = {} + runner.in_progress_prompt_logprobs = {} + runner.forward_batch_reqs_list = [None] * bsz + runner._cached_launch_token_num = -1 + runner._cached_real_bsz = 0 + runner.exist_prefill_flag = True + runner.proposer = Mock() + runner.sampler = Mock() + runner.model_config = Mock(eos_tokens_lens=1) + runner.share_inputs = self._make_share_inputs(bsz=bsz, max_draft=num_spec_tokens + 2) + + fd_config = Mock() + fd_config.scheduler_config.splitwise_role = "decode" + fd_config.routing_replay_config.enable_routing_replay = False + runner.fd_config = fd_config + runner.scheduler_config = fd_config.scheduler_config + return runner + + def _make_prefill_request(self, idx, draft_token_ids): + from unittest.mock import Mock + + from fastdeploy.engine.request import RequestType + + req = Mock() + req.task_type = Mock(value=RequestType.PREFILL.value) + req.idx = idx + req.request_id = f"req_{idx}" + req.prompt_token_ids = [10, 20, 30] + req.output_token_ids = [99] + req.draft_token_ids = draft_token_ids + req.pooling_params = None + req.guided_json = None + req.guided_regex = None + req.structural_tag = None + req.guided_grammar = None + req.prefill_start_index = 0 + req.prefill_end_index = 3 + req.multimodal_inputs = None + req.get = Mock(return_value=None) + req.eos_token_ids = [2] + req.block_tables = [] + return req + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_draft_tokens_and_seq_lens_written(self, _mock_asv): + """draft_tokens[0:2] and seq_lens_this_time_buffer=2 are written.""" + runner = self._make_runner(num_spec_tokens=3) + req = self._make_prefill_request(idx=0, draft_token_ids=[101, 202, 303]) + runner.insert_tasks_v1([req], num_running_requests=1) + + self.assertEqual(runner.share_inputs["draft_tokens"][0, 0], 101) + self.assertEqual(runner.share_inputs["draft_tokens"][0, 1], 202) + self.assertEqual(runner.share_inputs["seq_lens_this_time_buffer"][0, 0], 2) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_exist_prefill_flag_cleared(self, _mock_asv): + runner = self._make_runner() + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + self.assertFalse(runner.exist_prefill_flag) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_cached_launch_token_num_incremented(self, _mock_asv): + runner = self._make_runner(num_spec_tokens=3) + runner._cached_launch_token_num = 10 + runner._cached_real_bsz = 2 + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + # token_num_one_step = num_speculative_tokens + 1 = 4 + self.assertEqual(runner._cached_launch_token_num, 14) + self.assertEqual(runner._cached_real_bsz, 3) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_cached_launch_token_num_skipped_when_negative_one(self, _mock_asv): + runner = self._make_runner(num_spec_tokens=3) + runner._cached_launch_token_num = -1 + req = self._make_prefill_request(idx=0, draft_token_ids=[1, 2]) + runner.insert_tasks_v1([req], num_running_requests=1) + self.assertEqual(runner._cached_launch_token_num, -1) + + @patch("fastdeploy.worker.gpu_model_runner.async_set_value", side_effect=_sync_async_set_value) + def test_raises_when_fewer_than_two_draft_tokens(self, _mock_asv): + runner = self._make_runner() + req = self._make_prefill_request(idx=0, draft_token_ids=[42]) + with self.assertRaises(ValueError): + runner.insert_tasks_v1([req], num_running_requests=1) + + if __name__ == "__main__": unittest.main() diff --git a/tests/worker/test_gpu_prompt_logprobs.py b/tests/worker/test_gpu_prompt_logprobs.py index d26bc915339..f12bc4cf3dc 100644 --- a/tests/worker/test_gpu_prompt_logprobs.py +++ b/tests/worker/test_gpu_prompt_logprobs.py @@ -64,6 +64,7 @@ class SpecaulativeConfig: scheduler_config = SchedulerConfig() cache_config = CacheConfig() parallel_config = ParallelConfig() + enable_mm_runtime = model_config.enable_mm def get_max_chunk_tokens(self, mm_max_tokens_per_item=None): return 8192 diff --git a/tests/worker/test_recover_batch_index_sampling_mask.py b/tests/worker/test_recover_batch_index_sampling_mask.py new file mode 100644 index 00000000000..6119faa0685 --- /dev/null +++ b/tests/worker/test_recover_batch_index_sampling_mask.py @@ -0,0 +1,113 @@ +from unittest.mock import Mock + +import numpy as np +import paddle +import pytest + +from fastdeploy.worker.input_batch import recover_batch_index_for_sampler_output + + +def _make_sampler_output(batch_size, with_sampling_mask=True): + """Create a minimal mock SamplerOutput for testing reorder logic.""" + so = Mock() + so.sampled_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1) + so.logprobs_tensors = Mock() + so.logprobs_tensors.logprob_token_ids = paddle.arange(batch_size, dtype="int64").unsqueeze(1) + so.logprobs_tensors.logprobs = paddle.arange(batch_size, dtype="float32").unsqueeze(1) + so.logprobs_tensors.selected_token_ranks = paddle.zeros([batch_size, 1], dtype="int64") + so.token_num_per_batch = None + so.cu_batch_token_offset = None + so.logits = None + + if with_sampling_mask: + so.sampling_mask = [np.array([i * 10, i * 10 + 1, i * 10 + 2]) for i in range(batch_size)] + else: + so.sampling_mask = None + + return so + + +class TestRecoverBatchIndexSamplingMask: + """Test sampling_mask reordering in recover_batch_index_for_sampler_output.""" + + def test_no_sampling_mask_no_error(self): + """SamplerOutput without sampling_mask should not raise.""" + so = _make_sampler_output(batch_size=4, with_sampling_mask=False) + index_to_batch_id = {0: 2, 1: 0, 2: 3, 3: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + assert so.sampling_mask is None + + def test_sampling_mask_reorder_matches_token_ids(self): + """After reorder, sampling_mask[i] should correspond to sampled_token_ids[i].""" + batch_size = 4 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + + original_masks = [m.copy() for m in so.sampling_mask] + + # index_to_batch_id = {0:2, 1:0, 2:3, 3:1} + # src_order = [k for k,v in sorted(..., key=v)] = [1, 3, 0, 2] + # result[i] = src[src_order[i]] + index_to_batch_id = {0: 2, 1: 0, 2: 3, 3: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + reordered_token_ids = so.sampled_token_ids.numpy().flatten() + for i in range(batch_size): + token_id = int(reordered_token_ids[i]) + expected_mask = original_masks[token_id] + np.testing.assert_array_equal( + so.sampling_mask[i], + expected_mask, + err_msg=f"Position {i}: sampling_mask doesn't match sampled_token_ids", + ) + + def test_identity_reorder_is_noop(self): + """When index_to_batch_id is identity, function returns early without changes.""" + batch_size = 3 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + + index_to_batch_id = {0: 0, 1: 1, 2: 2} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + for i in range(batch_size): + np.testing.assert_array_equal(so.sampling_mask[i], original_masks[i]) + + def test_pd_reorder_disabled_is_noop(self): + """When enable_pd_reorder=False, nothing is reordered.""" + batch_size = 3 + so = _make_sampler_output(batch_size=batch_size, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + original_token_ids = so.sampled_token_ids.clone() + + index_to_batch_id = {0: 2, 1: 0, 2: 1} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=False) + + assert paddle.equal_all(so.sampled_token_ids, original_token_ids) + for i in range(batch_size): + np.testing.assert_array_equal(so.sampling_mask[i], original_masks[i]) + + def test_sampling_mask_longer_than_sort_len(self): + """Tail elements beyond sort_len are preserved in place.""" + so = _make_sampler_output(batch_size=5, with_sampling_mask=True) + original_masks = [m.copy() for m in so.sampling_mask] + + # Only reorder first 3 positions; positions 3,4 should stay put + index_to_batch_id = {0: 1, 1: 2, 2: 0} + + recover_batch_index_for_sampler_output(so, index_to_batch_id, enable_pd_reorder=True) + + # src_order = [2, 0, 1] + np.testing.assert_array_equal(so.sampling_mask[0], original_masks[2]) + np.testing.assert_array_equal(so.sampling_mask[1], original_masks[0]) + np.testing.assert_array_equal(so.sampling_mask[2], original_masks[1]) + np.testing.assert_array_equal(so.sampling_mask[3], original_masks[3]) + np.testing.assert_array_equal(so.sampling_mask[4], original_masks[4]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_reorder_split_prefill_and_decode.py b/tests/worker/test_reorder_split_prefill_and_decode.py index aff9f551cf4..147e9581201 100644 --- a/tests/worker/test_reorder_split_prefill_and_decode.py +++ b/tests/worker/test_reorder_split_prefill_and_decode.py @@ -59,6 +59,7 @@ def create_mock_config(): scheduler_config = Mock(spec=SchedulerConfig) scheduler_config.max_num_seqs = 10 + scheduler_config.max_num_batched_tokens = 2048 speculative_config = Mock(spec=SpeculativeConfig) speculative_config.method = None @@ -83,6 +84,7 @@ def create_mock_config(): fd_config.parallel_config = parallel_config fd_config.structured_outputs_config = structured_outputs_config fd_config.pad_to = 8 + fd_config.enable_mm_runtime = model_config.enable_mm def get_max_chunk_tokens(mm_max_tokens_per_item=None): return 100 diff --git a/tests/xpu_ci/conftest.py b/tests/xpu_ci/conftest.py index ae0c95d727a..dc6e4d30262 100644 --- a/tests/xpu_ci/conftest.py +++ b/tests/xpu_ci/conftest.py @@ -101,6 +101,13 @@ def safe_kill_cmd(cmd): for cmd in commands: safe_kill_cmd(cmd) + try: + # 清理/dev/shm下的所有文件 + subprocess.run("rm -rf /dev/shm/*", shell=True, check=True) + except subprocess.CalledProcessError: + print("Failed to remove files from /dev/shm") + pass + def cleanup_resources(): """