diff --git a/build.py b/build.py
index f00703e76..056ec727b 100755
--- a/build.py
+++ b/build.py
@@ -9,11 +9,37 @@
 CRS_PROJ_DIR = Path("/out/crs/proj")
 OSS_SRC_DIR = Path(os.getenv("SRC", "/src"))
 CRS_SRC_DIR = Path("/out/crs/src")
+CRS_CODEQL_DB_DIR = Path("/out/crs/codeql-db")
 
 
-def run_ossfuzz_build():
-    env = os.environ.copy()
-    subprocess.run(["/usr/local/bin/compile"], env=env, check=True)
+def build_and_create_codeql_database():
+    """Run OSS-Fuzz compile under CodeQL trace.
+
+    This both builds the target (jars, fuzz harnesses → /out) and creates the
+    CodeQL database, in a single invocation. Running compile twice (once normally,
+    once under trace) would cause Maven to skip recompilation on the second run,
+    leaving the CodeQL database with stub-only classes.
+    """
+    print(f"Creating CodeQL database from {OSS_SRC_DIR}")
+    if CRS_CODEQL_DB_DIR.exists():
+        shutil.rmtree(CRS_CODEQL_DB_DIR)
+
+    # check=True: a failure here means the traced compile itself failed,
+    # so the whole build should abort.
+    subprocess.run(
+        [
+            "codeql", "database", "create",
+            str(CRS_CODEQL_DB_DIR),
+            "--language=java",
+            f"--source-root={OSS_SRC_DIR}",
+            "--command=/usr/local/bin/compile",
+            "--overwrite",
+        ],
+        check=True,
+    )
+
+    if CRS_CODEQL_DB_DIR.exists():
+        print(f"CodeQL database created at {CRS_CODEQL_DB_DIR}")
+    else:
+        print("WARNING: CodeQL database creation did not produce a database")
 
 
 def prepare_crs_src():
@@ -28,10 +54,12 @@ def submit_build_outputs():
     subprocess.run(["libCRS", "submit-build-output", "/out", "build"], check=True)
     subprocess.run(["libCRS", "submit-build-output", str(CRS_PROJ_DIR), "crs/proj"], check=True)
     subprocess.run(["libCRS", "submit-build-output", str(CRS_SRC_DIR), "crs/src"], check=True)
+    if CRS_CODEQL_DB_DIR.exists():
+        subprocess.run(["libCRS", "submit-build-output", str(CRS_CODEQL_DB_DIR), "crs/codeql-db"], check=True)
 
 
 def main():
-    run_ossfuzz_build()
+    build_and_create_codeql_database()
     prepare_crs_src()
     submit_build_outputs()
 
diff --git a/crs/Dockerfile.crs b/crs/Dockerfile.crs
index 37c395e66..3c479e2f5 100644
--- a/crs/Dockerfile.crs
+++ b/crs/Dockerfile.crs
@@ -77,89 +77,89 @@ RUN mkdir /classpath && mkdir /classpath/atl-jazzer && \
 #################################################################################
 ## Staged Images - Prebuilt atl-libafl-jazzer
 #################################################################################
-FROM aixcc_afc_builder_base AS atl_jazzer_libafl_builder
-
-# libAFL requires bindgen which requires libclang.
-ENV DEBIAN_FRONTEND=noninteractive
-RUN apt-get update && apt-get install -y libclang-dev
-
-COPY fuzzers/atl-libafl-jazzer /app/crs-cp-java/fuzzers/atl-libafl-jazzer
-COPY fuzzers/jazzer-libafl /app/crs-cp-java/fuzzers/jazzer-libafl
-WORKDIR /app/crs-cp-java/fuzzers/atl-libafl-jazzer
-RUN yes | adduser --disabled-password builder && \
-    chown -R builder . && chown -R builder ../jazzer-libafl/
-
-USER builder:builder
-# Install Rust.
-RUN curl https://sh.rustup.rs -sSf | sh -s -- --component llvm-tools --default-toolchain nightly-2025-06-04 -y -ENV PATH="/home/builder/.cargo/bin:${PATH}" -RUN ln -sf /home/builder/.rustup/toolchains/nightly-2025-06-04-x86_64-unknown-linux-gnu /home/builder/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu -RUN echo "build --java_runtime_version=local_jdk_17" >> .bazelrc \ - && echo "build --cxxopt=-stdlib=libc++" >> .bazelrc \ - && echo "build --linkopt=-lc++" >> .bazelrc -RUN echo "build --experimental_repository_downloader_retries=5" >> .bazelrc \ - && echo "build --http_timeout_scaling=2" >> .bazelrc \ - && echo "build --repository_cache=/home/builder/.cache/bazel-repo" >> .bazelrc -RUN bazel build \ - //src/main/java/com/code_intelligence/jazzer:jazzer_standalone_deploy.jar \ - //deploy:jazzer-api \ - //deploy:jazzer-junit \ - //launcher:jazzer -RUN mkdir out && \ - cp $(bazel cquery --output=files //src/main/java/com/code_intelligence/jazzer:jazzer_standalone_deploy.jar) out/jazzer_agent_deploy.jar && \ - cp $(bazel cquery --output=files //launcher:jazzer) out/jazzer_driver && \ - cp $(bazel cquery --output=files //deploy:jazzer-api) out/jazzer_api_deploy.jar && \ - cp $(bazel cquery --output=files //deploy:jazzer-junit) out/jazzer_junit.jar - -USER root -RUN mkdir /classpath && mkdir /classpath/atl-libafl-jazzer && \ - cp out/jazzer_agent_deploy.jar /classpath/atl-libafl-jazzer/jazzer_standalone_deploy.jar && \ - cp out/jazzer_driver /classpath/atl-libafl-jazzer/jazzer && \ - cp out/jazzer_junit.jar /classpath/atl-libafl-jazzer/ && \ - cp out/jazzer_api_deploy.jar /classpath/atl-libafl-jazzer/ +#FROM aixcc_afc_builder_base AS atl_jazzer_libafl_builder +# +## libAFL requires bindgen which requires libclang. +#ENV DEBIAN_FRONTEND=noninteractive +#RUN apt-get update && apt-get install -y libclang-dev +# +#COPY fuzzers/atl-libafl-jazzer /app/crs-cp-java/fuzzers/atl-libafl-jazzer +#COPY fuzzers/jazzer-libafl /app/crs-cp-java/fuzzers/jazzer-libafl +#WORKDIR /app/crs-cp-java/fuzzers/atl-libafl-jazzer +#RUN yes | adduser --disabled-password builder && \ +# chown -R builder . && chown -R builder ../jazzer-libafl/ +# +#USER builder:builder +## Install Rust. 
+#RUN curl https://sh.rustup.rs -sSf | sh -s -- --component llvm-tools --default-toolchain nightly-2025-06-04 -y +#ENV PATH="/home/builder/.cargo/bin:${PATH}" +#RUN ln -sf /home/builder/.rustup/toolchains/nightly-2025-06-04-x86_64-unknown-linux-gnu /home/builder/.rustup/toolchains/nightly-x86_64-unknown-linux-gnu +#RUN echo "build --java_runtime_version=local_jdk_17" >> .bazelrc \ +# && echo "build --cxxopt=-stdlib=libc++" >> .bazelrc \ +# && echo "build --linkopt=-lc++" >> .bazelrc +#RUN echo "build --experimental_repository_downloader_retries=5" >> .bazelrc \ +# && echo "build --http_timeout_scaling=2" >> .bazelrc \ +# && echo "build --repository_cache=/home/builder/.cache/bazel-repo" >> .bazelrc +#RUN bazel build \ +# //src/main/java/com/code_intelligence/jazzer:jazzer_standalone_deploy.jar \ +# //deploy:jazzer-api \ +# //deploy:jazzer-junit \ +# //launcher:jazzer +#RUN mkdir out && \ +# cp $(bazel cquery --output=files //src/main/java/com/code_intelligence/jazzer:jazzer_standalone_deploy.jar) out/jazzer_agent_deploy.jar && \ +# cp $(bazel cquery --output=files //launcher:jazzer) out/jazzer_driver && \ +# cp $(bazel cquery --output=files //deploy:jazzer-api) out/jazzer_api_deploy.jar && \ +# cp $(bazel cquery --output=files //deploy:jazzer-junit) out/jazzer_junit.jar +# +#USER root +#RUN mkdir /classpath && mkdir /classpath/atl-libafl-jazzer && \ +# cp out/jazzer_agent_deploy.jar /classpath/atl-libafl-jazzer/jazzer_standalone_deploy.jar && \ +# cp out/jazzer_driver /classpath/atl-libafl-jazzer/jazzer && \ +# cp out/jazzer_junit.jar /classpath/atl-libafl-jazzer/ && \ +# cp out/jazzer_api_deploy.jar /classpath/atl-libafl-jazzer/ ################################################################################# ## Staged Images - Concolic Engine Build (/graal-jdk) ################################################################################# -FROM ubuntu:20.04 AS espresso_builder -ARG DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get install -y \ - curl \ - python3.9 \ - python3-pip \ - git \ - zip \ - wget \ - build-essential \ - libstdc++-9-dev - -# copy mx -COPY concolic/graal-concolic/mx /mx - -ENV PATH="/mx:$PATH" -RUN echo Y | mx fetch-jdk labsjdk-ce-21 -RUN python3 -mpip install ninja_syntax - -WORKDIR /graal-jdk -RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.6/cmake-3.31.6-linux-x86_64.tar.gz && \ - tar -xzvf cmake-3.31.6-linux-x86_64.tar.gz && \ - cp -r cmake-3.31.6-linux-x86_64/* /usr/ && \ - rm -rf cmake-3.31.6-linux-x86_64 - -# graal espresso -COPY concolic/graal-concolic/graal-jdk-25-14 /graal-jdk -COPY concolic/graal-concolic/docker-scripts /docker-scripts -WORKDIR /graal-jdk -RUN chmod +x /docker-scripts/* - -# put mx deps -COPY --from=graal_deps /root/.mx /root/.mx - -ENV MODE="jvm-ce" -ENV PREPARE_CMD="pushd /graal-jdk/espresso && mx --env=$MODE build --targets LLVM_TOOLCHAIN && mx --env $MODE create-generated-sources" -RUN /docker-scripts/init_dev.sh /bin/bash -c "$PREPARE_CMD" -ENV BUILD_CMD="pushd /graal-jdk/espresso && export MX_BUILD_EXPLODED=$MX_BUILD_EXPLODED && mx --env $MODE create-generated-sources && mx --env $MODE build" -RUN /docker-scripts/init_dev.sh /bin/bash -c "$BUILD_CMD" +#FROM ubuntu:20.04 AS espresso_builder +#ARG DEBIAN_FRONTEND=noninteractive +#RUN apt-get update && apt-get install -y \ +# curl \ +# python3.9 \ +# python3-pip \ +# git \ +# zip \ +# wget \ +# build-essential \ +# libstdc++-9-dev +# +## copy mx +#COPY concolic/graal-concolic/mx /mx +# +#ENV PATH="/mx:$PATH" +#RUN echo Y | mx fetch-jdk labsjdk-ce-21 
+#RUN python3 -mpip install ninja_syntax +# +#WORKDIR /graal-jdk +#RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.6/cmake-3.31.6-linux-x86_64.tar.gz && \ +# tar -xzvf cmake-3.31.6-linux-x86_64.tar.gz && \ +# cp -r cmake-3.31.6-linux-x86_64/* /usr/ && \ +# rm -rf cmake-3.31.6-linux-x86_64 +# +## graal espresso +#COPY concolic/graal-concolic/graal-jdk-25-14 /graal-jdk +#COPY concolic/graal-concolic/docker-scripts /docker-scripts +#WORKDIR /graal-jdk +#RUN chmod +x /docker-scripts/* +# +## put mx deps +#COPY --from=graal_deps /root/.mx /root/.mx +# +#ENV MODE="jvm-ce" +#ENV PREPARE_CMD="pushd /graal-jdk/espresso && mx --env=$MODE build --targets LLVM_TOOLCHAIN && mx --env $MODE create-generated-sources" +#RUN /docker-scripts/init_dev.sh /bin/bash -c "$PREPARE_CMD" +#ENV BUILD_CMD="pushd /graal-jdk/espresso && export MX_BUILD_EXPLODED=$MX_BUILD_EXPLODED && mx --env $MODE create-generated-sources && mx --env $MODE build" +#RUN /docker-scripts/init_dev.sh /bin/bash -c "$BUILD_CMD" ################################################################################# ## Staged Images - Prebuilt Joern (Public) @@ -200,13 +200,26 @@ RUN mkdir -p $JAVA_CRS_SRC && chmod -R 0755 $JAVA_CRS_SRC ## CRS-java fuzzer binaries COPY --from=aixcc_jazzer_builder /classpath/aixcc-jazzer /classpath/aixcc-jazzer COPY --from=atl_jazzer_builder /classpath/atl-jazzer /classpath/atl-jazzer -COPY --from=atl_jazzer_libafl_builder /classpath/atl-libafl-jazzer /classpath/atl-libafl-jazzer +#COPY --from=atl_jazzer_libafl_builder /classpath/atl-libafl-jazzer /classpath/atl-libafl-jazzer COPY ./fuzzers/mock-jazzer /classpath/mock-jazzer ENV AIXCC_JAZZER_DIR=/classpath/aixcc-jazzer ENV ATL_JAZZER_DIR=/classpath/atl-jazzer ENV ATL_JAZZER_LIBAFL_DIR=/classpath/atl-libafl-jazzer ENV ATL_MOCK_JAZZER_DIR=/classpath/mock-jazzer +## CRS-java atl-asm and atl-soot (must run before pip install coordinates) +COPY ./prebuilt ${JAVA_CRS_SRC}/prebuilt +RUN cd ${JAVA_CRS_SRC}/prebuilt && \ + ./mvn_install.sh +ENV JACOCO_CLI_DIR=${JAVA_CRS_SRC}/prebuilt/jacococli + +## joern +COPY --from=joern_builder /opt/joern ${JAVA_CRS_SRC}/joern +ENV JOERN_DIR=${JAVA_CRS_SRC}/joern/Joern +ENV JOERN_CLI=$JOERN_DIR/joern-cli +ENV JAVA2CPG=$JOERN_DIR/joern-cli/frontends/javasrc2cpg/bin +ENV PATH=$PATH:$JAVA_HOME/bin:$JOERN_CLI:$JAVA2CPG + ## crs python package deps COPY ./libs ${JAVA_CRS_SRC}/libs RUN /venv/bin/pip install --no-cache-dir \ @@ -223,19 +236,6 @@ RUN /venv/bin/pip install --no-cache-dir \ ${JAVA_CRS_SRC}/libs/claude-code-sdk-python && \ rm -rf /root/.cache/pip -## joern -COPY --from=joern_builder /opt/joern ${JAVA_CRS_SRC}/joern -ENV JOERN_DIR=${JAVA_CRS_SRC}/joern/Joern -ENV JOERN_CLI=$JOERN_DIR/joern-cli -ENV JAVA2CPG=$JOERN_DIR/joern-cli/frontends/javasrc2cpg/bin -ENV PATH=$PATH:$JAVA_HOME/bin:$JOERN_CLI:$JAVA2CPG - -## CRS-java atl-asm and atl-soot -COPY ./prebuilt ${JAVA_CRS_SRC}/prebuilt -RUN cd ${JAVA_CRS_SRC}/prebuilt && \ - ./mvn_install.sh -ENV JACOCO_CLI_DIR=${JAVA_CRS_SRC}/prebuilt/jacococli - ## jazzer-llm-augmented COPY ./jazzer-llm-augmented ${JAVA_CRS_SRC}/jazzer-llm-augmented #RUN cd ${JAVA_CRS_SRC}/jazzer-llm-augmented/ProgramExecutionTracer && \ @@ -254,6 +254,12 @@ COPY ./codeql ${JAVA_CRS_SRC}/codeql RUN cd ${JAVA_CRS_SRC}/codeql && \ ./init.sh +## filtering-agent (LLM-powered exploitability assessment) +COPY ./filtering-agent ${JAVA_CRS_SRC}/filtering-agent +RUN cd ${JAVA_CRS_SRC}/filtering-agent && \ + if [ -f requirements.txt ]; then /venv/bin/pip install --no-cache-dir -r requirements.txt; fi && \ + rm 
-rf /root/.cache/pip + ## llm-poc-gen COPY ./llm-poc-gen ${JAVA_CRS_SRC}/llm-poc-gen ENV PATH=${PATH}:/root/.local/bin @@ -271,21 +277,21 @@ RUN cd ${JAVA_CRS_SRC}/expkit && \ /venv/bin/pip install -e . && \ rm -rf /root/.cache/pip -## Build espresso-JDK-dependent components -## concolic executor engine (espresso-jdk & runtime) -# copy binary only -COPY --from=espresso_builder /graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/ /graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/ -# copy only the executor and provider -COPY ./concolic/graal-concolic/executor /graal-jdk/concolic/graal-concolic/executor -COPY ./concolic/graal-concolic/provider /graal-jdk/concolic/graal-concolic/provider -COPY ./concolic/graal-concolic/scheduler /graal-jdk/concolic/graal-concolic/scheduler -RUN cd /graal-jdk/concolic/graal-concolic/executor && \ - JAVA_HOME=/graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/graalvm-espresso-jvm-ce-openjdk-21.0.2+13.1/ ./gradlew build && \ - /venv/bin/pip install --no-cache-dir -r /graal-jdk/concolic/graal-concolic/executor/scripts/requirements.txt && \ - cd /graal-jdk/concolic/graal-concolic/provider && \ - JAVA_HOME=/graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/graalvm-espresso-jvm-ce-openjdk-21.0.2+13.1/ ./gradlew build && \ - rm -rf /root/.cache/coursier && \ - rm -rf /root/.cache/pip +### Build espresso-JDK-dependent components +### concolic executor engine (espresso-jdk & runtime) +## copy binary only +#COPY --from=espresso_builder /graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/ /graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/ +## copy only the executor and provider +#COPY ./concolic/graal-concolic/executor /graal-jdk/concolic/graal-concolic/executor +#COPY ./concolic/graal-concolic/provider /graal-jdk/concolic/graal-concolic/provider +#COPY ./concolic/graal-concolic/scheduler /graal-jdk/concolic/graal-concolic/scheduler +#RUN cd /graal-jdk/concolic/graal-concolic/executor && \ +# JAVA_HOME=/graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/graalvm-espresso-jvm-ce-openjdk-21.0.2+13.1/ ./gradlew build && \ +# /venv/bin/pip install --no-cache-dir -r /graal-jdk/concolic/graal-concolic/executor/scripts/requirements.txt && \ +# cd /graal-jdk/concolic/graal-concolic/provider && \ +# JAVA_HOME=/graal-jdk/sdk/mxbuild/linux-amd64/GRAALVM_ESPRESSO_JVM_CE_JAVA21/graalvm-espresso-jvm-ce-openjdk-21.0.2+13.1/ ./gradlew build && \ +# rm -rf /root/.cache/coursier && \ +# rm -rf /root/.cache/pip ## dictgen COPY ./dictgen ${JAVA_CRS_SRC}/dictgen @@ -330,13 +336,11 @@ RUN cd ${JAVA_CRS_SRC}/deepgen/jvm/stuck-point-analyzer && \ rm -rf /root/.cache/pip ## crs-java main entry -COPY ./*.sh ./*.py ./requirements.txt ./jazzer_driver_stub ./crs-java.config ./sink-targets.txt ${JAVA_CRS_SRC}/ +COPY ./*.sh ./*.py ./requirements.txt ./jazzer_driver_stub ./crs-java.config ${JAVA_CRS_SRC}/ COPY ./javacrs_modules ${JAVA_CRS_SRC}/javacrs_modules COPY ./tests ${JAVA_CRS_SRC}/tests RUN /venv/bin/pip install --no-cache-dir -r ${JAVA_CRS_SRC}/requirements.txt && \ rm -rf /root/.cache/pip -ENV JAVA_CRS_SINK_TARGET_CONF=${JAVA_CRS_SRC}/sink-targets.txt -ENV JAVA_CRS_CUSTOM_SINK_YAML=${JAVA_CRS_SRC}/codeql/sink_definitions.yml ## git setup # git/python-git will not work if CP repo is of unknown user: @@ -352,10 +356,5 @@ COPY --from=aixcc_afc_builder_base /usr/local/bin/jazzer_driver /classpath/raw-j COPY --from=aixcc_afc_builder_base /usr/local/bin/jazzer_junit.jar /classpath/raw-jazzer/ COPY 
--from=aixcc_afc_builder_base /usr/local/lib/jazzer_api_deploy.jar /classpath/raw-jazzer/ -COPY ssmode-lpg.toml ${JAVA_CRS_SRC}/ssmode-lpg.toml -COPY ssmode-sink.txt ${JAVA_CRS_SRC}/ssmode-sink.txt -RUN mkdir -p ${JAVA_CRS_SRC}/llm-poc-gen/eval/sheet && \ - cp ${JAVA_CRS_SRC}/ssmode-lpg.toml ${JAVA_CRS_SRC}/llm-poc-gen/eval/sheet/cpv.toml - WORKDIR ${JAVA_CRS_SRC} CMD ${JAVA_CRS_SRC}/run-crs-java.sh diff --git a/crs/assets/r2-test-config/r2-apache-commons-compress-diff-1.config b/crs/assets/r2-test-config/r2-apache-commons-compress-diff-1.config index d95e9a3f4..58a468da8 100644 --- a/crs/assets/r2-test-config/r2-apache-commons-compress-diff-1.config +++ b/crs/assets/r2-test-config/r2-apache-commons-compress-diff-1.config @@ -85,8 +85,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, "stub": { "enabled": false diff --git a/crs/assets/r2-test-config/r2-apache-commons-compress-diff-2.config b/crs/assets/r2-test-config/r2-apache-commons-compress-diff-2.config index d5856f6bd..dbac5bf9d 100644 --- a/crs/assets/r2-test-config/r2-apache-commons-compress-diff-2.config +++ b/crs/assets/r2-test-config/r2-apache-commons-compress-diff-2.config @@ -85,8 +85,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, "stub": { "enabled": false diff --git a/crs/assets/r2-test-config/r2-apache-commons-compress.config b/crs/assets/r2-test-config/r2-apache-commons-compress.config index b64a73b7f..724f658eb 100644 --- a/crs/assets/r2-test-config/r2-apache-commons-compress.config +++ b/crs/assets/r2-test-config/r2-apache-commons-compress.config @@ -85,8 +85,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, "stub": { "enabled": false diff --git a/crs/assets/r2-test-config/r2-zookeeper-diff-1.config b/crs/assets/r2-test-config/r2-zookeeper-diff-1.config index 520feb0d6..ba963841c 100644 --- a/crs/assets/r2-test-config/r2-zookeeper-diff-1.config +++ b/crs/assets/r2-test-config/r2-zookeeper-diff-1.config @@ -85,8 +85,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, "stub": { "enabled": false diff --git a/crs/assets/r2-test-config/r2-zookeeper.config b/crs/assets/r2-test-config/r2-zookeeper.config index 62dcb26bf..6ce175611 100644 --- a/crs/assets/r2-test-config/r2-zookeeper.config +++ b/crs/assets/r2-test-config/r2-zookeeper.config @@ -85,8 +85,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, 
"stub": { "enabled": false diff --git a/crs/assets/sarif-test-config/r2-full-mode-apache-common-compress.config b/crs/assets/sarif-test-config/r2-full-mode-apache-common-compress.config index 5724a7f13..50cb6ab8f 100644 --- a/crs/assets/sarif-test-config/r2-full-mode-apache-common-compress.config +++ b/crs/assets/sarif-test-config/r2-full-mode-apache-common-compress.config @@ -88,8 +88,7 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "o3:2,claude-opus-4-20250514:1,gemini-2.5-pro-preview-05-06:1,grok-3-mini-beta:1,none:15", - "x_models": "gpt-4.1:2,claude-sonnet-4-20250514:1,gemini-2.5-flash-preview-05-20:1,grok-3-mini-beta:1" + "gen_model": "gpt-5" }, "dictgen": { "enabled": true, diff --git a/crs/codeql/README.md b/crs/codeql/README.md index b1116faf2..62924f2aa 100644 --- a/crs/codeql/README.md +++ b/crs/codeql/README.md @@ -1,51 +1,54 @@ # CodeQL Sink Analysis Tool -This tool runs a CodeQL query to identify additional security sinks in Java code and transforms the results into a coordinate-based format suitable for further analysis. +Runs a curated set of CWE-specific CodeQL queries against a Java database and +emits a coordinate-format JSON list of sink locations suitable for downstream +consumption by the CRS sink pipeline. ## Setup -Before running the analysis, you need to initialize the project: - ```bash ./init.sh ``` -This will: -1. Install required Python dependencies (PyYAML, Jinja2) -2. Generate CodeQL model and query files from centralized sink definitions -3. Install the CodeQL pack - -This requires: -1. CodeQL CLI installed and available in PATH -2. Python 3.x +This installs Python dependencies and fetches the CodeQL packs declared in +`sink-queries/qlpack.yml` (notably `codeql/java-all`) into the local CodeQL +package cache. Requires `codeql` in `PATH` and Python 3. ## Usage -### Basic Usage - ```bash -./run.sh +CODEQL_CWES=CWE-078,CWE-089 ./run.sh [threads] ``` **Parameters:** -- `database_path` - Path to the CodeQL database to analyze -- `output_json_path` - Path where the transformed coordinate format JSON will be saved +- `database_path` — CodeQL database to analyze (must be finalized) +- `output_json_path` — where the coordinate-format JSON will be written +- `threads` — optional, passed to `codeql database analyze` (default: 4) -**Example:** -```bash -./run.sh test-db results.json -``` +**Environment:** +- `CODEQL_CWES` — required, comma-separated subset of supported CWEs to run. + The script hardcodes an explicit CWE → query file mapping; unknown CWEs + are rejected. + +### What `run.sh` does -### What the Script Does +1. Validates `CODEQL_CWES` and resolves each CWE to its `.ql` file. +2. Runs `codeql database analyze` against the database with the selected + queries, writing per-query SARIF to a `_sarifs` directory adjacent to the + output path. +3. Transforms all SARIF files into a single coordinate-format JSON via + `transform_results.py`. -1. **Runs CodeQL Query**: Executes the sink detection query against the specified database -2. **Decodes Results**: Converts BQRS output to JSON format (temporarily) -3. **Transforms Format**: Converts the CodeQL JSON format to coordinate format -4. **Outputs Results**: Saves the final coordinate format to the specified output file +### Reachability filter -### Output Format +After `run.sh`, `find_reachable_sinks.py` can annotate each sink with a +`reachable` flag by cross-referencing a call graph. 
+
-The script outputs JSON in coordinate format where each entry looks like:
+## Output format
+
+Each entry in the output JSON looks like:
 
 ```json
 {
@@ -53,89 +56,36 @@
   "coord": {
     "line_num": 342,
     "method_name": "tokenizeRow",
     "file_name": "BasicCParser.java",
-    "bytecode_offset": -1,
-    "method_desc": "(Ljava/lang/String;)[Ljava/lang/String;",
-    "mark_desc": "sink-RegexInjection",
-    "method_signature": "org.apache.commons.imaging.common.BasicCParser: java.lang.String[] tokenizeRow(java.lang.String)",
-    "class_name": "org/apache/commons/imaging/common/BasicCParser"
+    "class_name": "org/apache/commons/imaging/common/BasicCParser",
+    ...
   },
-  "id": "Sink: java.util.regex; Pattern; false; compile; (String); static; Argument[0]; regex-use; manual"
+  "cwe": "CWE-730",
+  "filtered_out_flow": false,
+  "filtered_out_test": false
 }
 ```
 
-## Metadata Retrieval
-
-You can retrieve metadata for any sink definition using its ID:
-
-```bash
-./get_metadata.sh "<sink_id>"
-```
-
-**Example:**
-```bash
-./get_metadata.sh "Sink: java.io; File; false; ; (String); ; Argument[0]; path-injection; manual"
-```
-
-This will output whatever metadata is defined in the sink_definitions.yml file in YAML format, for example:
-```yaml
-category: file-system
-cwe: CWE-22
-description: File constructor that accepts a pathname string
-severity: medium
-```
-
-The sink ID can be retrieved from the analysis output file.
-
-## Architecture
+## Layout
 
-This project uses a centralized approach for managing sink definitions with separated model and metadata components.
-
-### File Structure
-
-```
+```
-├── sink_definitions.yml       # Central sink definitions with model and metadata
-├── scripts/                   # Python scripts
-│   ├── generate_models.py     # Generates CodeQL model and query files
-│   ├── get_metadata.py        # Retrieves metadata by sink ID
-│   └── transform_results.py   # Transforms CodeQL results to coordinate format
-├── templates/                 # Jinja2 templates for code generation
-│   ├── model.yml.j2           # Template for CodeQL model files
-│   └── sinks.ql.j2            # Template for CodeQL query file
-├── sinks-pack/                # Generated CodeQL pack
-│   ├── models/                # Generated model files (one per package)
-│   │   ├── java.io.model.yml
-│   │   ├── java.lang.model.yml
-│   │   └── ...
-│ └── queries/ -│ └── sinks.ql # Generated query file -├── init.sh # Initialization script -├── run.sh # Analysis script -└── get_metadata.sh # Metadata retrieval script ``` - -### Adding New Sink Definitions - -Edit `sink_definitions.yml` and add entries with `model` (CodeQL fields) and `metadata` (additional info) sections: - -```yaml -sink_definitions: - - model: - package: "java.example" - type: "ExampleClass" - subtypes: false - name: "vulnerableMethod" - signature: "(String)" - ext: "" - input: "Argument[0]" - kind: "example-injection" - provenance: "manual" - metadata: - description: "Description of the sink" - category: "example-category" - severity: "medium" - cwe: "CWE-XXX" +├── init.sh # Installs deps and the CodeQL pack +├── run.sh # Runs the selected CWE queries +├── transform_results.py # SARIF -> coordinate JSON +├── find_reachable_sinks.py # Annotates sinks with call-graph reachability +├── requirements.txt # Python deps for init.sh +└── sink-queries/ # QL pack + ├── qlpack.yml # Declares codeql/java-all dependency + └── queries/ + ├── CWE-022-path-traversal.ql + ├── CWE-078-command-injection.ql + ├── CWE-089-sql-injection.ql + ├── CWE-090-ldap-injection.ql + ├── CWE-094-script-injection.ql + ├── CWE-117-log-injection.ql + ├── CWE-470-unsafe-reflection.ql + ├── CWE-502-unsafe-deserialization.ql + ├── CWE-611-xxe.ql + ├── CWE-643-xpath-injection.ql + ├── CWE-730-regex-injection.ql + └── CWE-918-ssrf.ql ``` - -After adding definitions, run `./init.sh` to regenerate the pack. - -For more information on CodeQL model definitions, see the [CodeQL documentation](https://codeql.github.com/docs/codeql-language-guides/customizing-library-models-for-java-and-kotlin/). diff --git a/crs/codeql/find_reachable_sinks.py b/crs/codeql/find_reachable_sinks.py new file mode 100755 index 000000000..51504d4cf --- /dev/null +++ b/crs/codeql/find_reachable_sinks.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +Find sink points reachable from fuzzing entry points in a call graph. + +This script identifies all sink points that are reachable from any fuzzing +entry point (methods named 'fuzzerTestOneInput') in the call graph. +""" + +import json +import argparse +from collections import deque +from typing import Set, Dict, List + + +def load_json_file(filepath: str) -> dict: + """Load and parse a JSON file.""" + with open(filepath, 'r') as f: + return json.load(f) + + +def find_entry_points(nodes: List[dict]) -> Set[int]: + """ + Find all fuzzing entry points in the call graph. + + Entry points are identified by method name 'fuzzerTestOneInput'. + + Args: + nodes: List of node dictionaries from the call graph + + Returns: + Set of node IDs that are fuzzing entry points + """ + entry_points = set() + for node in nodes: + func_name = node.get('data', {}).get('func_name', '') + if func_name == 'fuzzerTestOneInput': + entry_points.add(node['id']) + return entry_points + + +def build_call_graph(links: List[dict]) -> Dict[int, Set[int]]: + """ + Build a call graph as an adjacency list. + + Args: + links: List of link dictionaries with 'source' and 'target' keys + + Returns: + Dictionary mapping source node ID to set of target node IDs + """ + graph = {} + for link in links: + source = link['source'] + target = link['target'] + if source not in graph: + graph[source] = set() + graph[source].add(target) + return graph + + +def find_reachable_nodes(entry_points: Set[int], graph: Dict[int, Set[int]]) -> Set[int]: + """ + Find all nodes reachable from any entry point using BFS. 
+
+    Args:
+        entry_points: Set of entry point node IDs
+        graph: Call graph as adjacency list
+
+    Returns:
+        Set of all reachable node IDs (including entry points)
+    """
+    reachable = set()
+    queue = deque(entry_points)
+    reachable.update(entry_points)
+
+    while queue:
+        current = queue.popleft()
+
+        # Get all targets called by current node
+        targets = graph.get(current, set())
+        for target in targets:
+            if target not in reachable:
+                reachable.add(target)
+                queue.append(target)
+
+    return reachable
+
+
+def create_node_index(nodes: List[dict]) -> Dict[int, dict]:
+    """
+    Create an index mapping node IDs to node data.
+
+    Args:
+        nodes: List of node dictionaries
+
+    Returns:
+        Dictionary mapping node ID to node data
+    """
+    return {node['id']: node for node in nodes}
+
+
+def match_sink_to_node(sink: dict, nodes: List[dict]) -> int:
+    """
+    Match a sink point to a node in the call graph.
+
+    Args:
+        sink: Sink dictionary with coord information
+        nodes: List of node dictionaries from call graph
+
+    Returns:
+        Node ID if match found, None otherwise
+    """
+    sink_file = sink['coord']['file_name']
+    sink_line = sink['coord']['line_num']
+
+    for node in nodes:
+        node_data = node.get('data', {})
+        node_file = node_data.get('file_name', '')
+        start_line = node_data.get('start_line', -1)
+        end_line = node_data.get('end_line', -1)
+
+        # Handle path prefix differences between call graph and sink files
+        # CG may have "repo/" or "oss-fuzz/projects/aixcc/jvm/<project>/" prefixes
+        node_file_normalized = node_file
+        if node_file.startswith('repo/'):
+            node_file_normalized = node_file[5:]  # Remove "repo/"
+        elif 'oss-fuzz/projects/aixcc/jvm/' in node_file:
+            # Remove prefix up to and including the project directory
+            parts = node_file.split('oss-fuzz/projects/aixcc/jvm/', 1)
+            if len(parts) > 1:
+                # Also skip the project name directory
+                remaining = parts[1].split('/', 1)
+                if len(remaining) > 1:
+                    node_file_normalized = remaining[1]
+
+        # Match if normalized paths match and line is within range
+        if node_file_normalized == sink_file and start_line <= sink_line <= end_line:
+            return node['id']
+
+    return None
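+
+# A worked example of the normalization above (paths illustrative): a call
+# graph node file "repo/src/main/java/org/foo/Bar.java" and a sink file
+# "src/main/java/org/foo/Bar.java" compare equal once the "repo/" prefix is
+# stripped; likewise "oss-fuzz/projects/aixcc/jvm/<project>/src/..." is
+# reduced to "src/...".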
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Find sink points reachable from fuzzing entry points'
+    )
+    parser.add_argument(
+        'sinks_file',
+        help='Path to JSON file containing sink points (e.g., geonetwork.json)'
+    )
+    parser.add_argument(
+        'callgraph_file',
+        help='Path to JSON file containing call graph (e.g., joern-cg.json)'
+    )
+    parser.add_argument(
+        '--output',
+        '-o',
+        help='Output file for reachable sinks (default: stdout)'
+    )
+    parser.add_argument(
+        '--in-place',
+        '-i',
+        action='store_true',
+        help='Modify the sinks file in-place instead of creating a new file'
+    )
+    parser.add_argument(
+        '--include-filtered',
+        action='store_true',
+        help='Include sinks that were filtered out'
+    )
+    parser.add_argument(
+        '--verbose',
+        '-v',
+        action='store_true',
+        help='Print verbose information'
+    )
+
+    args = parser.parse_args()
+
+    # Validate arguments
+    if args.in_place and args.output:
+        parser.error("Cannot use both --in-place and --output")
+
+    # Load input files
+    if args.verbose:
+        print(f"Loading sinks from {args.sinks_file}...")
+    sinks = load_json_file(args.sinks_file)
+
+    if args.verbose:
+        print(f"Loading call graph from {args.callgraph_file}...")
+    callgraph = load_json_file(args.callgraph_file)
+
+    nodes = callgraph['nodes']
+    links = callgraph['links']
+
+    # Find entry points
+    if args.verbose:
+        print("Finding fuzzing entry points...")
+    entry_points = find_entry_points(nodes)
+
+    if args.verbose:
+        print(f"Found {len(entry_points)} entry point(s)")
+        node_index = create_node_index(nodes)
+        for ep_id in entry_points:
+            ep_data = node_index[ep_id]['data']
+            print(f"  - {ep_data.get('class_name', '')}.{ep_data.get('func_name', '')} "
+                  f"at {ep_data.get('file_name', '')}:{ep_data.get('start_line', '')}")
+
+    if not entry_points:
+        print("WARNING: No fuzzing entry points found!")
+        if args.output:
+            with open(args.output, 'w') as f:
+                json.dump([], f, indent=2)
+        else:
+            print("[]")
+        return
+
+    # Build call graph
+    if args.verbose:
+        print("Building call graph...")
+    graph = build_call_graph(links)
+
+    # Find reachable nodes
+    if args.verbose:
+        print("Finding reachable nodes...")
+    reachable = find_reachable_nodes(entry_points, graph)
+
+    if args.verbose:
+        print(f"Found {len(reachable)} reachable node(s)")
+
+    # Match sinks to nodes and add reachability information
+    if args.verbose:
+        print("Matching sinks to reachable nodes...")
+
+    reachable_count = 0
+    unmatched_count = 0
+    output_sinks = []
+
+    for sink in sinks:
+        # Skip filtered sinks unless requested
+        if not args.include_filtered:
+            if sink.get('filtered_out_flow', False) or sink.get('filtered_out_test', False):
+                continue
+
+        # Create a copy of the sink to add reachability info
+        sink_with_reachability = sink.copy()
+
+        node_id = match_sink_to_node(sink, nodes)
+
+        if node_id is not None:
+            if node_id in reachable:
+                sink_with_reachability['reachable'] = True
+                reachable_count += 1
+            else:
+                sink_with_reachability['reachable'] = False
+        else:
+            # Sink not matched to any node in call graph
+            sink_with_reachability['reachable'] = False
+            unmatched_count += 1
+
+        output_sinks.append(sink_with_reachability)
+
+    # Output results
+    print("\nResults:")
+    print(f"  Total sinks processed: {len(output_sinks)}")
+    print(f"  Reachable sinks: {reachable_count}")
+    print(f"  Unreachable sinks: {len(output_sinks) - reachable_count}")
+    print(f"  Unmatched sinks (not in CG): {unmatched_count}")
+
+    # Determine output file
+    if args.in_place:
+        output_file = args.sinks_file
+    elif args.output:
+        output_file = args.output
+    else:
+        output_file = None
+
+    if output_file:
+        with open(output_file, 'w') as f:
+            json.dump(output_sinks, f, indent=2)
+        print(f"\nSinks with reachability info written to {output_file}")
+    else:
+        # Print human-readable summary to stdout
+        print("\nReachable sinks:")
+        for sink in output_sinks:
+            if sink.get('reachable', False):
+                print(f"  [{sink.get('cwe', 'N/A')}] {sink['coord']['file_name']}:{sink['coord']['line_num']}")
+
+        print("\nUnreachable sinks:")
+        for sink in output_sinks:
+            if not sink.get('reachable', False):
+                print(f"  [{sink.get('cwe', 'N/A')}] {sink['coord']['file_name']}:{sink['coord']['line_num']}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/crs/codeql/get_metadata.sh b/crs/codeql/get_metadata.sh
deleted file mode 100755
index 040255fe2..000000000
--- a/crs/codeql/get_metadata.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-set -e
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-
-# Usage: ./get_metadata.sh <sink_id>
-
-if [ $# -ne 1 ]; then
-    echo "Usage: $0 <sink_id>"
-    echo "Example: $0 'Sink: java.io; File; false; ; (String); ; Argument[0]; path-injection; manual'"
-    exit 1
-fi
-
-SINK_ID="$1"
-
-cd "$SCRIPT_DIR"
-
-# Retrieve metadata for the given sink ID
-python3 scripts/get_metadata.py "$SINK_ID"
diff --git a/crs/codeql/init.sh b/crs/codeql/init.sh
index 69e56e50b..cca72f1ae 100755
--- a/crs/codeql/init.sh
+++ b/crs/codeql/init.sh
@@ -9,13 +9,10 @@ cd "$SCRIPT_DIR"
 echo "Installing Python dependencies..."
 pip3 install -r requirements.txt
 
-# Generate model and query files from sink definitions
-echo "Generating CodeQL model and query files..."
-python3 scripts/generate_models.py
-
-# Install CodeQL pack
+# Install CodeQL pack (downloads codeql/java-all and transitive deps into
+# the CodeQL package cache so the analyze step can resolve library imports)
 echo "Installing CodeQL pack..."
-cd sinks-pack
+cd sink-queries
 codeql pack install
 
 echo "Initialization complete!"
diff --git a/crs/codeql/run.sh b/crs/codeql/run.sh
index cbd99db7e..233f9b6af 100755
--- a/crs/codeql/run.sh
+++ b/crs/codeql/run.sh
@@ -1,37 +1,66 @@
 #!/bin/bash
-
 set -x
 set -e
-SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-# Usage: ./run.sh <database_path> <output_json_path>
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
-if [ $# -ne 2 ]; then
-    echo "Usage: $0 <database_path> <output_json_path>"
-    echo "Example: $0 test-db results.json"
-    exit 1
-fi
+# Usage: ./run.sh <database_path> <output_json_path> [threads]
+# Environment: CODEQL_CWES - required, comma-separated CWE IDs (e.g. "CWE-078,CWE-089").
 
 DATABASE_PATH="$1"
 JSON_OUTPUT="$2"
+THREADS="${3:-4}"
 
+if [ -z "$CODEQL_CWES" ]; then
+    echo "ERROR: CODEQL_CWES is not set"
+    exit 1
+fi
+
+SARIF_DIR="${JSON_OUTPUT%.json}_sarifs"
+mkdir -p "$SARIF_DIR"
 
 cd "$SCRIPT_DIR"
 
-# Create temporary files for BQRS and raw JSON output
-TEMP_BQRS=$(mktemp --suffix=.bqrs)
-trap "rm -f $TEMP_BQRS" EXIT
+# Explicit CWE → query file mapping. Only CWEs listed here can be run.
+declare -A CWE_QUERIES=(
+    [CWE-022]="sink-queries/queries/CWE-022-path-traversal.ql"
+    [CWE-078]="sink-queries/queries/CWE-078-command-injection.ql"
+    [CWE-089]="sink-queries/queries/CWE-089-sql-injection.ql"
+    [CWE-090]="sink-queries/queries/CWE-090-ldap-injection.ql"
+    [CWE-094]="sink-queries/queries/CWE-094-script-injection.ql"
+    [CWE-117]="sink-queries/queries/CWE-117-log-injection.ql"
+    [CWE-470]="sink-queries/queries/CWE-470-unsafe-reflection.ql"
+    [CWE-502]="sink-queries/queries/CWE-502-unsafe-deserialization.ql"
+    [CWE-611]="sink-queries/queries/CWE-611-xxe.ql"
+    [CWE-643]="sink-queries/queries/CWE-643-xpath-injection.ql"
+    [CWE-730]="sink-queries/queries/CWE-730-regex-injection.ql"
+    [CWE-918]="sink-queries/queries/CWE-918-ssrf.ql"
+)
+
+# Resolve requested CWEs via the mapping.
+queries=()
+IFS=',' read -ra selected <<< "$CODEQL_CWES"
+for cwe in "${selected[@]}"; do
+    ql="${CWE_QUERIES[$cwe]}"
+    if [ -z "$ql" ]; then
+        echo "ERROR: Unknown CWE: $cwe"
+        exit 1
+    fi
+    queries+=("$ql")
+done
 
-# Intermediate JSON file for decoding
-interim_json="${JSON_OUTPUT%.json}_raw.json"
+if [ ${#queries[@]} -eq 0 ]; then
+    echo "ERROR: No queries found for CODEQL_CWES=$CODEQL_CWES"
+    exit 1
+fi
 
-codeql query run --database="$DATABASE_PATH" sinks-pack/queries/sinks.ql --output="$TEMP_BQRS"
+echo "Running ${#queries[@]} CodeQL queries on $DATABASE_PATH (threads=$THREADS)"
 
-echo
-# Decode to temporary JSON file first
-codeql bqrs decode --format=json --output="$interim_json" "$TEMP_BQRS"
-#codeql bqrs decode "$TEMP_BQRS"
+codeql database analyze "$DATABASE_PATH" \
+    --format=sarifv2.1.0 \
+    --output="$SARIF_DIR/results.sarif" \
+    --threads="$THREADS" \
+    --rerun \
+    "${queries[@]}"
 
-echo
-echo "Transforming results to coordinate format..."
-# Transform the temporary JSON to the final coordinate format -python3 scripts/transform_results.py "$interim_json" "$JSON_OUTPUT" +# Transform SARIF to coordinate JSON +python3 transform_results.py "$SARIF_DIR"/*.sarif "$JSON_OUTPUT" diff --git a/crs/codeql/scripts/generate_models.py b/crs/codeql/scripts/generate_models.py deleted file mode 100644 index 50e90cde0..000000000 --- a/crs/codeql/scripts/generate_models.py +++ /dev/null @@ -1,226 +0,0 @@ -#!/usr/bin/env python3 -""" -Generate CodeQL model and query files from centralized sink definitions. -""" - -import yaml -from jinja2 import Environment, FileSystemLoader -from pathlib import Path -from collections import defaultdict -import sys -import os -from dataclasses import dataclass -from typing import List, Dict, Set - - -@dataclass -class Sink: - """Represents a sink definition with all required properties.""" - package: str - class_name: str - subtypes: bool - name: str - signature: str - ext: str - input_arg: str - kind: str - provenance: str - metadata: Dict - - @classmethod - def from_dict(cls, data: Dict) -> 'Sink': - """Create a Sink instance from a dictionary.""" - model_data = data['model'] - return cls( - package=model_data['package'], - class_name=model_data['type'], - subtypes=model_data['subtypes'], - name=model_data['name'], - signature=model_data['signature'], - ext=model_data['ext'], - input_arg=model_data['input'], - kind=model_data['kind'], - provenance=model_data['provenance'], - metadata=data['metadata'] - ) - - def to_model_tuple(self) -> List: - """Convert sink to model tuple format for CodeQL.""" - return [ - self.package, - self.class_name, - self.subtypes, - self.name, - self.signature, - self.ext, - self.input_arg, - self.kind, - self.provenance - ] - - def get_id(self) -> str: - """Generate the ID string for this sink definition.""" - subtypes_str = "true" if self.subtypes else "false" - return f"Sink: {self.package}; {self.class_name}; {subtypes_str}; {self.name}; {self.signature}; {self.ext}; {self.input_arg}; {self.kind}; {self.provenance}" - - -def load_sink_definitions(file_path) -> List[Sink]: - """Load sink definitions from YAML file and return list of Sink objects.""" - try: - with open(file_path, 'r') as f: - data = yaml.safe_load(f) - except FileNotFoundError: - print(f"Error: Sink definitions file '{file_path}' not found", file=sys.stderr) - sys.exit(1) - except yaml.YAMLError as e: - print(f"Error parsing YAML file '{file_path}': {e}", file=sys.stderr) - sys.exit(1) - - if 'sink_definitions' not in data: - print("Error: Invalid sink definitions format - missing 'sink_definitions' key", file=sys.stderr) - sys.exit(1) - - # Convert dictionary data to Sink objects - sinks = [] - for sink_data in data['sink_definitions']: - try: - sink = Sink.from_dict(sink_data) - sinks.append(sink) - except KeyError as e: - print(f"Error: Missing required field {e} in sink definition: {sink_data}", file=sys.stderr) - sys.exit(1) - - return sinks - - -def group_sinks_by_package(sinks: List[Sink]) -> Dict[str, List[Sink]]: - """Group sink definitions by package for separate model files.""" - grouped = defaultdict(list) - for sink in sinks: - grouped[sink.package].append(sink) - return grouped - - -def generate_model_files(grouped_sinks, template_env, output_dir): - """Generate model files for each package.""" - model_template = template_env.get_template('model.yml.j2') - - # Ensure output directory exists - output_dir.mkdir(parents=True, exist_ok=True) - - generated_files = [] - for package, sinks in 
grouped_sinks.items(): - # Convert package name to filename (e.g., java.io -> java.io.model.yml) - output_file = output_dir / f"{package}.model.yml" - content = model_template.render(sinks=sinks) - - with open(output_file, 'w') as f: - f.write(content) - - generated_files.append(output_file) - print(f"Generated model file: {output_file}") - - return generated_files - - -def generate_query_file(sinks: List[Sink], template_env, output_file): - """Generate the sinks.ql file with all sink types.""" - query_template = template_env.get_template('sinks.ql.j2') - - # Extract unique sink kinds - sink_kinds = set() - for sink in sinks: - sink_kinds.add(sink.kind) - - # Sort sink kinds for consistent output - sorted_sink_kinds = sorted(sink_kinds) - - content = query_template.render(sink_types=sorted_sink_kinds) - - # Ensure output directory exists - output_file.parent.mkdir(parents=True, exist_ok=True) - - with open(output_file, 'w') as f: - f.write(content) - - print(f"Generated query file: {output_file}") - print(f"Included sink kinds: {', '.join(sorted_sink_kinds)}") - - -def clean_old_model_files(models_dir, generated_files): - """Remove old model files that are no longer generated.""" - if not models_dir.exists(): - return - - # Get all existing .model.yml files - existing_files = list(models_dir.glob("*.model.yml")) - generated_file_names = {f.name for f in generated_files} - - removed_count = 0 - for existing_file in existing_files: - if existing_file.name not in generated_file_names: - print(f"Removing old model file: {existing_file}") - existing_file.unlink() - removed_count += 1 - - if removed_count > 0: - print(f"Removed {removed_count} old model files") - else: - print("No old model files to remove") - - -def main(): - # Setup paths - script_dir = Path(__file__).parent - repo_root = script_dir.parent - sink_defs_file = repo_root / "sink_definitions.yml" - templates_dir = repo_root / "templates" - models_dir = repo_root / "sinks-pack" / "models" - queries_dir = repo_root / "sinks-pack" / "queries" - - # Validate required directories and files - if not templates_dir.exists(): - print(f"Error: Templates directory '{templates_dir}' does not exist", file=sys.stderr) - sys.exit(1) - - if not (templates_dir / "model.yml.j2").exists(): - print(f"Error: Model template '{templates_dir / 'model.yml.j2'}' does not exist", file=sys.stderr) - sys.exit(1) - - if not (templates_dir / "sinks.ql.j2").exists(): - print(f"Error: Query template '{templates_dir / 'sinks.ql.j2'}' does not exist", file=sys.stderr) - sys.exit(1) - - # Load sink definitions - print(f"Loading sink definitions from: {sink_defs_file}") - sinks = load_sink_definitions(sink_defs_file) - - total_sinks = len(sinks) - print(f"Loaded {total_sinks} sink definitions") - - # Setup Jinja2 environment - template_env = Environment( - loader=FileSystemLoader(templates_dir), - trim_blocks=False, - lstrip_blocks=True - ) - - # Generate files - print("\nGenerating model files...") - grouped_sinks = group_sinks_by_package(sinks) - generated_files = generate_model_files(grouped_sinks, template_env, models_dir) - - print(f"\nGenerating query file...") - generate_query_file(sinks, template_env, queries_dir / "sinks.ql") - - print(f"\nCleaning up old model files...") - clean_old_model_files(models_dir, generated_files) - - print(f"\nGeneration complete!") - print(f"- Generated {len(generated_files)} model files") - print(f"- Generated 1 query file") - print(f"- Total sink definitions processed: {total_sinks}") - - -if __name__ == "__main__": - main() 
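With the generator and metadata helpers removed, the checked-in scripts form
the whole sink pipeline. A minimal end-to-end sketch (database and file names
illustrative; the call graph is assumed to be produced elsewhere by the Joern
tooling):

```bash
./init.sh                                    # install Python deps + CodeQL pack
CODEQL_CWES=CWE-022,CWE-502 ./run.sh test-db results.json 8
./find_reachable_sinks.py results.json joern-cg.json --in-place
```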
diff --git a/crs/codeql/scripts/get_metadata.py b/crs/codeql/scripts/get_metadata.py
deleted file mode 100644
index b50256a21..000000000
--- a/crs/codeql/scripts/get_metadata.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python3
-"""
-Retrieve metadata for a sink definition given its ID.
-"""
-
-import yaml
-import sys
-from pathlib import Path
-
-
-def get_sink_id(model_data):
-    """Generate the ID string for a sink model."""
-    subtypes_str = "true" if model_data['subtypes'] else "false"
-    return f"Sink: {model_data['package']}; {model_data['type']}; {subtypes_str}; {model_data['name']}; {model_data['signature']}; {model_data['ext']}; {model_data['input']}; {model_data['kind']}; {model_data['provenance']}"
-
-
-def main():
-    if len(sys.argv) != 2:
-        print("Usage: python3 scripts/get_metadata.py <sink_id>", file=sys.stderr)
-        sys.exit(1)
-
-    target_id = sys.argv[1]
-
-    # Setup paths
-    script_dir = Path(__file__).parent
-    repo_root = script_dir.parent
-    sink_defs_file = repo_root / "sink_definitions.yml"
-
-    # Load sink definitions
-    try:
-        with open(sink_defs_file, 'r') as f:
-            data = yaml.safe_load(f)
-    except FileNotFoundError:
-        print(f"Error: Sink definitions file '{sink_defs_file}' not found", file=sys.stderr)
-        sys.exit(1)
-    except yaml.YAMLError as e:
-        print(f"Error parsing YAML file '{sink_defs_file}': {e}", file=sys.stderr)
-        sys.exit(1)
-
-    # Find matching sink definition
-    for sink_def in data['sink_definitions']:
-        sink_id = get_sink_id(sink_def['model'])
-        if sink_id == target_id:
-            # Output the metadata as YAML
-            yaml.dump(sink_def['metadata'], sys.stdout, default_flow_style=False)
-            return
-
-    # If we get here, no matching sink was found
-    print(f"Error: No sink definition found for ID: {target_id}", file=sys.stderr)
-    sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/crs/codeql/scripts/transform_results.py b/crs/codeql/scripts/transform_results.py
deleted file mode 100644
index 9071c2745..000000000
--- a/crs/codeql/scripts/transform_results.py
+++ /dev/null
@@ -1,119 +0,0 @@
-#!/usr/bin/env python3
-"""
-Transform CodeQL JSON results from the current format to coordinate format.
-""" - -import json -import sys -import os -import re -from pathlib import Path - - -def extract_filename_from_path(file_path): - """Extract just the filename from a full file path.""" - return os.path.basename(file_path) - - -def convert_class_name_to_jvm_format(class_name): - """Convert class name from dot notation to JVM slash notation.""" - return class_name.replace('.', '/') - - -def transform_codeql_results(input_file, output_file): - """Transform CodeQL JSON results to coordinate format.""" - - # Read the input JSON - with open(input_file, 'r') as f: - data = json.load(f) - - # Extract the tuples from the CodeQL result format - if '#select' not in data or 'tuples' not in data['#select']: - raise ValueError("Invalid CodeQL JSON format - missing #select.tuples") - - tuples = data['#select']['tuples'] - columns = data['#select']['columns'] - - # Create column index mapping - col_map = {col['name']: idx for idx, col in enumerate(columns) if 'name' in col} - col_map['entity'] = 0 # Entity is always first column - - # Transform each tuple to coordinate format - coordinates = [] - - for tuple_data in tuples: - try: - # Extract data from tuple - entity = tuple_data[col_map['entity']] - sink_type = tuple_data[col_map['sink_type']] - has_non_constant_args = tuple_data[col_map['has_non_constant_args']] - class_name = tuple_data[col_map['class_name']] - method_name = tuple_data[col_map['method_name']] - method_signature = tuple_data[col_map['method_signature']] - method_descriptor = tuple_data[col_map['method_descriptor']] - file_path = tuple_data[col_map['file_path']] - line_number = tuple_data[col_map['line_number']] - model_info = tuple_data[col_map['model_info']] - - if not has_non_constant_args: - continue - - # Extract filename from path - file_name = file_path # extract_filename_from_path(file_path) - - # Convert class name to JVM format - jvm_class_name = convert_class_name_to_jvm_format(class_name) - - # Map sink type to mark description - mark_desc = sink_type - - # Create coordinate entry - coord_entry = { - "coord": { - "line_num": line_number, - "method_name": method_name, - "file_name": file_name, - "bytecode_offset": -1, - "method_desc": method_descriptor, - "mark_desc": mark_desc, - "method_signature": method_signature, - "class_name": jvm_class_name - }, - "id": model_info - } - - coordinates.append(coord_entry) - - except (KeyError, IndexError) as e: - print(f"Warning: Skipping malformed tuple: {e}", file=sys.stderr) - continue - - # Write the transformed results - with open(output_file, 'w') as f: - json.dump(coordinates, f, indent=2) - - print(f"Transformed {len(coordinates)} entries from {input_file} to {output_file}") - - -def main(): - if len(sys.argv) != 3: - print("Usage: python3 transform_results.py ") - print("Example: python3 transform_results.py out.json transformed_results.json") - sys.exit(1) - - input_file = sys.argv[1] - output_file = sys.argv[2] - - if not os.path.exists(input_file): - print(f"Error: Input file '{input_file}' does not exist") - sys.exit(1) - - try: - transform_codeql_results(input_file, output_file) - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/crs/codeql/sinks-pack/.gitignore b/crs/codeql/sink-queries/.gitignore similarity index 100% rename from crs/codeql/sinks-pack/.gitignore rename to crs/codeql/sink-queries/.gitignore diff --git a/crs/codeql/sinks-pack/qlpack.yml b/crs/codeql/sink-queries/qlpack.yml similarity index 53% rename from 
crs/codeql/sinks-pack/qlpack.yml rename to crs/codeql/sink-queries/qlpack.yml index add28dd01..4f3907389 100644 --- a/crs/codeql/sinks-pack/qlpack.yml +++ b/crs/codeql/sink-queries/qlpack.yml @@ -1,10 +1,8 @@ --- library: false -name: sinks-query-pack +name: sink-queries version: 1.0.0 dependencies: - codeql/java-all: "*" -dataExtensions: - - models/*.yml + codeql/java-all: "7.3.0" defaultSuite: - query: queries/sinks.ql diff --git a/crs/codeql/sink-queries/queries/CWE-022-path-traversal.ql b/crs/codeql/sink-queries/queries/CWE-022-path-traversal.ql new file mode 100644 index 000000000..81a1a804e --- /dev/null +++ b/crs/codeql/sink-queries/queries/CWE-022-path-traversal.ql @@ -0,0 +1,94 @@ +/** + * @name Path traversal sinks + * @description Reports all path traversal sinks without dataflow tracking. + * @kind problem + * @problem.severity error + * @security-severity 7.5 + * @precision high + * @id java/cwe-022-path-traversal + * @tags security + * external/cwe/cwe-022 + * external/cwe/cwe-023 + * external/cwe/cwe-036 + * external/cwe/cwe-073 + */ + +import java +import semmle.code.java.dataflow.DataFlow +import semmle.code.java.security.TaintedPathQuery + +/** + * Gets the actual declaring type, traversing up from anonymous classes. + */ +RefType getActualDeclaringType(Callable c) { + exists(RefType declType | declType = c.getDeclaringType() | + if declType instanceof AnonymousClass then + result = declType.getEnclosingType() + else + result = declType + ) +} + +/** + * A call that creates a temp file with constant arguments (safe from path traversal). + */ +predicate isSafeFileCreation(Expr e) { + exists(MethodCall mc | + mc = e and + ( + // Files.createTempFile with constant arguments + mc.getMethod().hasQualifiedName("java.nio.file", "Files", "createTempFile") and + forall(Expr arg | arg = mc.getAnArgument() and arg.getType().(RefType).hasQualifiedName("java.lang", "String") | + arg instanceof CompileTimeConstantExpr + ) + or + // File.createTempFile with constant arguments + mc.getMethod().hasQualifiedName("java.io", "File", "createTempFile") and + forall(Expr arg | arg = mc.getAnArgument() and arg.getType().(RefType).hasQualifiedName("java.lang", "String") | + arg instanceof CompileTimeConstantExpr + ) + ) + ) +} + +/** + * Check if the expression flows from a safe file creation. 
+ */ +predicate flowsFromSafeFileCreation(Expr e) { + // Direct case: the expression itself is a safe file creation + isSafeFileCreation(e) + or + // Variable case: check if a variable was assigned from safe file creation + exists(Variable v, VarAccess va | + va = e and + v = va.getVariable() and + flowsFromSafeFileCreation(v.getAnAssignedValue()) + ) + or + // Method call on safe file: e.g., tmpFile.toFile() where tmpFile is safe + exists(MethodCall mc | + mc = e and + flowsFromSafeFileCreation(mc.getQualifier()) + ) +} + +from TaintedPathSink sink, Call call, Expr arg, string flag +where + sink.asExpr() = [call.getAnArgument(), call.getQualifier()] and + arg = sink.asExpr() and + if arg instanceof CompileTimeConstantExpr then + flag = "[FILTERED: compile-time constant]" + else if arg instanceof NullLiteral then + flag = "[FILTERED: null literal]" + else if arg.getType() instanceof PrimitiveType then + flag = "[FILTERED: primitive type]" + else if arg.getType().(RefType).hasQualifiedName("java.lang", ["Integer", "Long", "Boolean", "Double", "Float", "Short", "Byte"]) then + flag = "[FILTERED: boxed primitive]" + else if flowsFromSafeFileCreation(arg) then + flag = "[FILTERED: safe file creation]" + else + flag = "" +select call, + "[" + getActualDeclaringType(call.getEnclosingCallable()).getQualifiedName() + + ", " + call.getEnclosingCallable().getName() + + "] Path traversal sink " + flag diff --git a/crs/codeql/sink-queries/queries/CWE-078-command-injection.ql b/crs/codeql/sink-queries/queries/CWE-078-command-injection.ql new file mode 100644 index 000000000..31f2df25f --- /dev/null +++ b/crs/codeql/sink-queries/queries/CWE-078-command-injection.ql @@ -0,0 +1,57 @@ +/** + * @name Command injection sinks + * @description Reports all command injection sinks without dataflow tracking. 
* @kind problem
+ * @problem.severity error
+ * @security-severity 9.8
+ * @precision high
+ * @id java/cwe-078-command-injection
+ * @tags security
+ *       external/cwe/cwe-078
+ *       external/cwe/cwe-088
+ */
+
+import java
+import semmle.code.java.dataflow.DataFlow
+import semmle.code.java.dataflow.TaintTracking
+import semmle.code.java.security.CommandLineQuery
+import semmle.code.java.security.ExternalProcess
+
+// Global dataflow configuration to track from command injection sinks to their usage
+module CommandInjectionUsageConfig implements DataFlow::ConfigSig {
+  predicate isSource(DataFlow::Node source) {
+    exists(CommandInjectionSink sink, Call sinkCall |
+      (
+        sink.asExpr() = sinkCall.getAnArgument()
+        or
+        sink.(DataFlow::ImplicitVarargsArray).getCall() = sinkCall
+      ) and
+      source.asExpr() = sinkCall
+    )
+  }
+
+  predicate isSink(DataFlow::Node sink) {
+    // Any call where the source flows to the qualifier or an argument
+    exists(Call c | sink.asExpr() = [c.getQualifier(), c.getAnArgument()])
+  }
+}
+
+module CommandInjectionUsageFlow = DataFlow::Global<CommandInjectionUsageConfig>;
+
+from Call resultCall
+where
+  // Case 1: The original sink call itself (e.g., Runtime.exec(cmd))
+  exists(CommandInjectionSink sink |
+    sink.asExpr() = resultCall.getAnArgument()
+    or
+    sink.(DataFlow::ImplicitVarargsArray).getCall() = resultCall
+  )
+  or
+  // Case 2: A subsequent usage that the sink value flows to (e.g., procBuilder.start())
+  CommandInjectionUsageFlow::flow(_, DataFlow::exprNode(resultCall.getQualifier()))
+  or
+  CommandInjectionUsageFlow::flow(_, DataFlow::exprNode(resultCall.getAnArgument()))
+select resultCall,
+  "[" + resultCall.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+    ", " + resultCall.getEnclosingCallable().getName() +
+    "] Command injection sink"
diff --git a/crs/codeql/sink-queries/queries/CWE-089-sql-injection.ql b/crs/codeql/sink-queries/queries/CWE-089-sql-injection.ql
new file mode 100644
index 000000000..f50b25f36
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-089-sql-injection.ql
@@ -0,0 +1,21 @@
+/**
+ * @name SQL injection sinks
+ * @description Reports all SQL injection sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 8.8
+ * @precision high
+ * @id java/cwe-089-sql-injection
+ * @tags security
+ *       external/cwe/cwe-089
+ */
+
+import java
+import semmle.code.java.security.QueryInjection
+
+from QueryInjectionSink sink, Call call
+where sink.asExpr() = call.getAnArgument()
+select sink,
+  "[" + sink.asExpr().getEnclosingCallable().getDeclaringType().getQualifiedName() +
+    ", " + sink.asExpr().getEnclosingCallable().getName() +
+    "] SQL injection sink"
diff --git a/crs/codeql/sink-queries/queries/CWE-090-ldap-injection.ql b/crs/codeql/sink-queries/queries/CWE-090-ldap-injection.ql
new file mode 100644
index 000000000..2cc2024e6
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-090-ldap-injection.ql
@@ -0,0 +1,21 @@
+/**
+ * @name LDAP injection sinks
+ * @description Reports all LDAP injection sinks without dataflow tracking.
+ 
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.8
+ * @precision high
+ * @id java/cwe-090-ldap-injection
+ * @tags security
+ *       external/cwe/cwe-090
+ */
+
+import java
+import semmle.code.java.security.LdapInjection
+
+from LdapInjectionSink sink, Call call
+where sink.asExpr() = call.getAnArgument()
+select call,
+  "[" + call.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + call.getEnclosingCallable().getName() +
+  "] LDAP injection sink"
diff --git a/crs/codeql/sink-queries/queries/CWE-094-script-injection.ql b/crs/codeql/sink-queries/queries/CWE-094-script-injection.ql
new file mode 100644
index 000000000..a378a1392
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-094-script-injection.ql
@@ -0,0 +1,34 @@
+/**
+ * @name Script injection sinks (javax.script.ScriptEngine)
+ * @description Reports all script engine injection sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.3
+ * @precision high
+ * @id java/cwe-094-script-injection
+ * @tags security
+ *       external/cwe/cwe-094
+ */
+
+import java
+
+/** Copied from codeql's CWE-094/ScriptInjection.ql */
+class ScriptEngineMethod extends Method {
+  ScriptEngineMethod() {
+    this.getDeclaringType().getAnAncestor().hasQualifiedName("javax.script", "ScriptEngine") and
+    this.hasName("eval")
+    or
+    this.getDeclaringType().getAnAncestor().hasQualifiedName("javax.script", "Compilable") and
+    this.hasName("compile")
+    or
+    this.getDeclaringType().getAnAncestor().hasQualifiedName("javax.script", "ScriptEngineFactory") and
+    this.hasName(["getProgram", "getMethodCallSyntax"])
+  }
+}
+
+from MethodCall ma
+where ma.getMethod() instanceof ScriptEngineMethod
+select ma,
+  "[" + ma.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + ma.getEnclosingCallable().getName() +
+  "] Script engine injection sink: " + ma.getMethod().getName()
diff --git a/crs/codeql/sink-queries/queries/CWE-117-log-injection.ql b/crs/codeql/sink-queries/queries/CWE-117-log-injection.ql
new file mode 100644
index 000000000..8688c3dfc
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-117-log-injection.ql
@@ -0,0 +1,29 @@
+/**
+ * @name Log injection sinks
+ * @description Reports all logging sinks with user-controlled input without dataflow tracking. Includes Log4j (Log4Shell), SLF4J, java.util.logging, etc.
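+ *              Sinks with no recorded taint flow are still reported but tagged,
+ *              e.g. (illustrative only): "[com.example.Handler, process] Log
+ *              injection sink [FILTERED: no flow from source]", so downstream
+ *              consumers can deprioritize them.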
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 7.5
+ * @precision high
+ * @id java/cwe-117-log-injection
+ * @tags security
+ *       external/cwe/cwe-117
+ *       external/cwe/cwe-020
+ */
+
+import java
+import semmle.code.java.security.LogInjectionQuery
+import semmle.code.java.security.LogInjection
+
+from LogInjectionSink sink, Call call, string flag
+where
+  sink.asExpr() = call.getAnArgument() and
+  // Check if there's flow from a source
+  if not LogInjectionFlow::flow(_, sink)
+  then flag = "[FILTERED: no flow from source]"
+  else flag = ""
+select call,
+  "[" + call.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + call.getEnclosingCallable().getName() +
+  "] Log injection sink " + flag
diff --git a/crs/codeql/sink-queries/queries/CWE-470-unsafe-reflection.ql b/crs/codeql/sink-queries/queries/CWE-470-unsafe-reflection.ql
new file mode 100644
index 000000000..e58cc4f7f
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-470-unsafe-reflection.ql
@@ -0,0 +1,85 @@
+/**
+ * @name Unsafe reflection sinks
+ * @description Reports all reflection invocation sinks (Method.invoke and Constructor.newInstance) without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.0
+ * @precision high
+ * @id java/cwe-470-unsafe-reflection
+ * @tags security
+ *       external/cwe/cwe-470
+ */
+
+import java
+import semmle.code.java.dataflow.DataFlow
+
+/**
+ * A call to `java.lang.reflect.Method.invoke`.
+ */
+class MethodInvokeCall extends MethodCall {
+  MethodInvokeCall() { this.getMethod().hasQualifiedName("java.lang.reflect", "Method", "invoke") }
+}
+
+/**
+ * An expression that represents a hardcoded class (class literal or Class.forName with constant).
+ */
+predicate isHardcodedClass(Expr e) {
+  // Class literal: Foo.class
+  e instanceof TypeLiteral
+  or
+  // Class.forName("...") with constant first argument (class name)
+  exists(MethodCall forName |
+    forName.getMethod().hasQualifiedName("java.lang", "Class", "forName") and
+    forName.getArgument(0) instanceof CompileTimeConstantExpr and
+    e = forName
+  )
+}
+
+/**
+ * Check if the Constructor/Method was obtained from a hardcoded class.
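+ *
+ * For example (illustrative only), `Foo.class.getMethod("run").invoke(x)`
+ * should be tagged as filtered because the receiver originates from the class
+ * literal `Foo.class`, while `clazz.getMethod(name).invoke(x)` with an
+ * externally supplied `clazz` should not be filtered.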
+ */
+predicate flowsFromHardcodedClass(Expr qualifier) {
+  // The qualifier is the Constructor or Method object
+  // We need to find what it flows from
+  exists(MethodCall getReflectionObject |
+    // The qualifier flows from a call like getDeclaredConstructor() or getMethod()
+    DataFlow::localExprFlow(getReflectionObject, qualifier) and
+    // That call was made on a Class object
+    exists(Expr classExpr |
+      classExpr = getReflectionObject.getQualifier() and
+      // That Class object flows from a hardcoded class
+      exists(Expr source |
+        isHardcodedClass(source) and
+        DataFlow::localExprFlow(source, classExpr)
+      )
+    )
+  )
+  or
+  // Direct case: qualifier itself is from a hardcoded class
+  exists(Expr source |
+    isHardcodedClass(source) and
+    DataFlow::localExprFlow(source, qualifier)
+  )
+}
+
+from MethodCall ma, string flag
+where
+  (
+    ma.getMethod().getDeclaringType().getSourceDeclaration().hasQualifiedName("java.lang.reflect", "Constructor") and
+    ma.getMethod().hasName("newInstance")
+    or
+    ma instanceof MethodInvokeCall
+  ) and
+  // Filter out various safe cases
+  if ma.getQualifier() instanceof CompileTimeConstantExpr
+  then flag = "[FILTERED: compile-time constant]"
+  else if ma.getQualifier() instanceof NullLiteral
+  then flag = "[FILTERED: null literal]"
+  else if flowsFromHardcodedClass(ma.getQualifier())
+  then flag = "[FILTERED: hardcoded class]"
+  else flag = ""
+select ma,
+  "[" + ma.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + ma.getEnclosingCallable().getName() +
+  "] Unsafe reflection call: " + ma.getMethod().getName() + " " + flag
diff --git a/crs/codeql/sink-queries/queries/CWE-502-unsafe-deserialization.ql b/crs/codeql/sink-queries/queries/CWE-502-unsafe-deserialization.ql
new file mode 100644
index 000000000..1022368b9
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-502-unsafe-deserialization.ql
@@ -0,0 +1,51 @@
+/**
+ * @name Unsafe deserialization sinks (all sources)
+ * @description Reports all unsafe deserialization sinks from both hardcoded patterns and external models without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.8
+ * @precision high
+ * @id java/cwe-502-unsafe-deserialization
+ * @tags security
+ *       external/cwe/cwe-502
+ */
+
+import java
+import semmle.code.java.dataflow.DataFlow
+import semmle.code.java.dataflow.ExternalFlow
+import semmle.code.java.security.UnsafeDeserializationQuery
+
+/**
+ * Gets the actual declaring type, traversing up from anonymous classes.
+ */
+RefType getActualDeclaringType(Callable c) {
+  exists(RefType declType | declType = c.getDeclaringType() |
+    if declType instanceof AnonymousClass
+    then result = declType.getEnclosingType()
+    else result = declType
+  )
+}
+
+// Models-as-data sinks
+class DefaultUnsafeDeserializationSink extends DataFlow::Node {
+  DefaultUnsafeDeserializationSink() { sinkNode(this, "unsafe-deserialization") }
+}
+
+from Expr sink, string msg
+where
+  // Hardcoded framework patterns (ObjectInputStream, XStream, Kryo, Jackson, etc.)
+  (
+    unsafeDeserialization(_, sink) and
+    msg = "Unsafe deserialization sink (hardcoded pattern)"
+  )
+  or
+  // External models (models-as-data)
+  (
+    sink = any(DefaultUnsafeDeserializationSink s).asExpr() and
+    msg = "Unsafe deserialization sink (external model)"
+  )
+select sink,
+  "[" + getActualDeclaringType(sink.getEnclosingCallable()).getQualifiedName() +
+  ", " + sink.getEnclosingCallable().getName() +
+  "] " + msg
diff --git a/crs/codeql/sink-queries/queries/CWE-611-xxe.ql b/crs/codeql/sink-queries/queries/CWE-611-xxe.ql
new file mode 100644
index 000000000..74e9ecabc
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-611-xxe.ql
@@ -0,0 +1,21 @@
+/**
+ * @name XXE (XML External Entity) sinks
+ * @description Reports all XXE sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.1
+ * @precision high
+ * @id java/cwe-611-xxe
+ * @tags security
+ *       external/cwe/cwe-611
+ */
+
+import java
+import semmle.code.java.security.XmlParsers
+
+from XmlParserCall parse
+where not parse.isSafe()
+select parse,
+  "[" + parse.getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + parse.getEnclosingCallable().getName() +
+  "] XXE sink: unsafe XML parsing"
diff --git a/crs/codeql/sink-queries/queries/CWE-643-xpath-injection.ql b/crs/codeql/sink-queries/queries/CWE-643-xpath-injection.ql
new file mode 100644
index 000000000..e1fb90f66
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-643-xpath-injection.ql
@@ -0,0 +1,20 @@
+/**
+ * @name XPath injection sinks
+ * @description Reports all XPath injection sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.8
+ * @precision high
+ * @id java/cwe-643-xpath-injection
+ * @tags security
+ *       external/cwe/cwe-643
+ */
+
+import java
+import semmle.code.java.security.XPath
+
+from XPathInjectionSink sink
+select sink,
+  "[" + sink.asExpr().getEnclosingCallable().getDeclaringType().getQualifiedName() +
+  ", " + sink.asExpr().getEnclosingCallable().getName() +
+  "] XPath injection sink"
diff --git a/crs/codeql/sink-queries/queries/CWE-730-regex-injection.ql b/crs/codeql/sink-queries/queries/CWE-730-regex-injection.ql
new file mode 100644
index 000000000..cc9fe0ca3
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-730-regex-injection.ql
@@ -0,0 +1,47 @@
+/**
+ * @name Regex injection sinks
+ * @description Reports all regex injection sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 7.5
+ * @precision high
+ * @id java/cwe-730-regex-injection
+ * @tags security
+ *       external/cwe/cwe-1333
+ *       external/cwe/cwe-730
+ *       external/cwe/cwe-400
+ */
+
+import java
+import semmle.code.java.security.regexp.RegexInjection
+
+/**
+ * Gets the actual declaring type, traversing up from anonymous classes.
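+ *
+ * For example (illustrative only), a sink inside `new Runnable() { ... }`
+ * defined in `com.example.Service` is attributed to `com.example.Service`
+ * rather than to the synthetic name of the anonymous class.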
+ */
+RefType getActualDeclaringType(Callable c) {
+  exists(RefType declType | declType = c.getDeclaringType() |
+    if declType instanceof AnonymousClass
+    then result = declType.getEnclosingType()
+    else result = declType
+  )
+}
+
+from RegexInjectionSink sink, Call call, Expr arg, string flag
+where
+  sink.asExpr() = call.getAnArgument() and
+  arg = sink.asExpr() and
+  if arg instanceof CompileTimeConstantExpr
+  then flag = "[FILTERED: compile-time constant]"
+  else if arg instanceof NullLiteral
+  then flag = "[FILTERED: null literal]"
+  else if arg.getType() instanceof PrimitiveType
+  then flag = "[FILTERED: primitive type]"
+  else if arg.getType().(RefType).hasQualifiedName("java.lang", ["Integer", "Long", "Boolean", "Double", "Float", "Short", "Byte"])
+  then flag = "[FILTERED: boxed primitive]"
+  else flag = ""
+select call,
+  "[" + getActualDeclaringType(sink.asExpr().getEnclosingCallable()).getQualifiedName() +
+  ", " + call.getEnclosingCallable().getName() +
+  "] Regex injection sink " + flag
diff --git a/crs/codeql/sink-queries/queries/CWE-918-ssrf.ql b/crs/codeql/sink-queries/queries/CWE-918-ssrf.ql
new file mode 100644
index 000000000..f21df798f
--- /dev/null
+++ b/crs/codeql/sink-queries/queries/CWE-918-ssrf.ql
@@ -0,0 +1,54 @@
+/**
+ * @name Server-side request forgery sinks
+ * @description Reports all SSRF sinks without dataflow tracking.
+ * @kind problem
+ * @problem.severity error
+ * @security-severity 9.1
+ * @precision high
+ * @id java/cwe-918-ssrf
+ * @tags security
+ *       external/cwe/cwe-918
+ */
+
+import java
+import semmle.code.java.dataflow.DataFlow
+import semmle.code.java.security.RequestForgery
+
+/**
+ * Gets the actual declaring type, traversing up from anonymous classes.
+ */
+RefType getActualDeclaringType(Callable c) {
+  exists(RefType declType | declType = c.getDeclaringType() |
+    if declType instanceof AnonymousClass
+    then result = declType.getEnclosingType()
+    else result = declType
+  )
+}
+
+from RequestForgerySink sink, Call sinkCall, Call resultCall, Expr arg, string flag
+where
+  // Find the call that contains the sink
+  sink.asExpr() = [sinkCall.getAnArgument(), sinkCall.getQualifier()] and
+  arg = sink.asExpr() and
+  (
+    // Case 1: Direct chained call - new URL(x).openStream()
+    resultCall = sinkCall
+    or
+    // Case 2: Separate usage - URLConnection urlC = url.openConnection(); urlC.getInputStream()
+    DataFlow::localExprFlow(sinkCall, resultCall.getQualifier())
+  ) and
+  if arg instanceof CompileTimeConstantExpr
+  then flag = "[FILTERED: compile-time constant]"
+  else if arg instanceof NullLiteral
+  then flag = "[FILTERED: null literal]"
+  else if arg.getType() instanceof PrimitiveType
+  then flag = "[FILTERED: primitive type]"
+  else if arg.getType().(RefType).hasQualifiedName("java.lang", ["Integer", "Long", "Boolean", "Double", "Float", "Short", "Byte"])
+  then flag = "[FILTERED: boxed primitive]"
+  else flag = ""
+select resultCall,
+  "[" + getActualDeclaringType(resultCall.getEnclosingCallable()).getQualifiedName() +
+  ", " + resultCall.getEnclosingCallable().getName() +
+  "] Server-side request forgery sink " + flag
diff --git a/crs/codeql/sink_definitions.yml b/crs/codeql/sink_definitions.yml
deleted file mode 100644
index ffe3abf9d..000000000
--- a/crs/codeql/sink_definitions.yml
+++ /dev/null
@@ -1,456 +0,0 @@
-sink_definitions:
-  - model:
-      package: "java.math"
-      type: "BigDecimal"
-      subtypes: false
-      name: "BigDecimal"
-      signature: "(String)"
-      ext: ""
-      input: "Argument[0]"
-      kind: "sink-BigDecimal"
-      provenance: "manual"
-
metadata: - description: "DoS for BigDecimal" - - model: - package: "java.math" - type: "BigDecimal" - subtypes: false - name: "BigDecimal" - signature: "(String,MathContext)" - ext: "" - input: "Argument[0]" - kind: "sink-BigDecimal" - provenance: "manual" - metadata: - description: "DoS for BigDecimal" - - model: - package: "java.math" - type: "BigDecimal" - subtypes: false - name: "BigDecimal" - signature: "(char[])" - ext: "" - input: "Argument[0]" - kind: "sink-BigDecimal" - provenance: "manual" - metadata: - description: "DoS for BigDecimal" - - model: - package: "java.math" - type: "BigDecimal" - subtypes: false - name: "BigDecimal" - signature: "(char[],int,int)" - ext: "" - input: "Argument[0]" - kind: "sink-BigDecimal" - provenance: "manual" - metadata: - description: "DoS for BigDecimal" - - model: - package: "java.math" - type: "BigDecimal" - subtypes: false - name: "BigDecimal" - signature: "(char[],int,int,MathContext)" - ext: "" - input: "Argument[0]" - kind: "sink-BigDecimal" - provenance: "manual" - metadata: - description: "DoS for BigDecimal" - - model: - package: "java.math" - type: "BigDecimal" - subtypes: false - name: "BigDecimal" - signature: "(char[],MathContext)" - ext: "" - input: "Argument[0]" - kind: "sink-BigDecimal" - provenance: "manual" - metadata: - description: "DoS for BigDecimal" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(File,DefaultHandler)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(File,HandlerBase)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputSource,DefaultHandler)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputSource,HandlerBase)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputStream,DefaultHandler)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputStream,DefaultHandler,String)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputStream,HandlerBase)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(InputStream,HandlerBase,String)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - 
name: "parse" - signature: "(String,DefaultHandler)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - model: - package: "javax.xml.parsers" - type: "SAXParser" - subtypes: false - name: "parse" - signature: "(String,HandlerBase)" - ext: "" - input: "Argument[0]" - kind: "sink-SAXParser" - provenance: "manual" - metadata: - description: "SSRF by SAXParser" - - # java.net.URL - - model: # URL(String spec) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(String)" - ext: "" - input: "Argument[0]" # spec - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL(URL context, String spec) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(URL,String)" - ext: "" - input: "Argument[1]" # spec - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL(URL context, String spec, URLStreamHandler handler) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(URL,String,URLStreamHandler)" - ext: "" - input: "Argument[1]" # spec - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL(String protocol, String host, String file) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(String,String,String)" - ext: "" - input: "Argument[1]" # host - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL(String protocol, String host, int port, String file) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(String,String,int,String)" - ext: "" - input: "Argument[1]" # host - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL(String protocol, String host, int port, String file, URLStreamHandler handler) - package: "java.net" - type: "URL" - subtypes: false - name: "URL" - signature: "(String,String,int,String,URLStreamHandler)" - ext: "" - input: "Argument[1]" # host - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - - model: # URL.of(URI uri, URLStreamHandler handler) - # Note that this sink is only available in JDK 20+ - package: "java.net" - type: "URL" - subtypes: false - name: "of" - signature: "(URI,URLStreamHandler)" - ext: "" - input: "Argument[0]" # uri - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URL" - # NOTE: URL.fabricateNewURL() would create new instance but it's hard to track - - # java.net.URI - - model: # URI(String str) - package: "java.net" - type: "URI" - subtypes: false - name: "URI" - signature: "(String)" - ext: "" - input: "Argument[0]" # str - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - model: # URI(String scheme, String userInfo, String host, int port, String path, String query, String fragment) - package: "java.net" - type: "URI" - subtypes: false - name: "URI" - signature: "(String,String,String,int,String,String,String)" - ext: "" - input: "Argument[2]" # host - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - model: # URI(String scheme, String authority, String path, String query, String fragment) - package: 
"java.net" - type: "URI" - subtypes: false - name: "URI" - signature: "(String,String,String,String,String)" - ext: "" - input: "Argument[1]" # authority - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - model: # URI(String scheme, String host, String path, String fragment) - package: "java.net" - type: "URI" - subtypes: false - name: "URI" - signature: "(String,String,String,String)" - ext: "" - input: "Argument[1]" # host - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - model: # URI(String scheme, String ssp, String fragment) - package: "java.net" - type: "URI" - subtypes: false - name: "URI" - signature: "(String,String,String)" - ext: "" - input: "Argument[1]" # ssp - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - model: # java.net.URI.create(String) - package: "java.net" - type: "URI" - subtypes: false - name: "create" - signature: "(String)" - ext: "" - input: "Argument[0]" # uri - kind: "sink-ServerSideRequestForgery" - provenance: "manual" - metadata: - description: "SSRF by URI" - - # javax.validation.Validator - - model: # Validator.validate(T object, Class... groups) - package: "javax.validation" - type: "Validator" - subtypes: false - name: "validate" - signature: "(Object,Class[])" - ext: "" - input: "Argument[0]" # object - kind: "sink-ExpressionLanguageInjection" - provenance: "manual" - metadata: - description: "Remote code execution by Validator" - - model: # Validator.validateProperty(T object, String propertyName, Class... groups) - package: "javax.validation" - type: "Validator" - subtypes: false - name: "validateProperty" - signature: "(Object,String,Class[])" - ext: "" - input: "Argument[0]" # object - kind: "sink-ExpressionLanguageInjection" - provenance: "manual" - metadata: - description: "Remote code execution by Validator" - - model: # validateValue(Class beanType, String propertyName, Object value, Class... 
groups) - # NOTE: I'm not sure if this is a vulerable sink - package: "javax.validation" - type: "Validator" - subtypes: false - name: "validateValue" - signature: "(Class,String,Object,Class[])" - ext: "" - input: "Argument[0]" # beanType - kind: "sink-ExpressionLanguageInjection" - provenance: "manual" - metadata: - description: "Remote code execution by Validator" - - # org.apache.batik - - model: # TranscoderInput(Document document) - package: "org.apache.batik.transcoder" - type: "TranscoderInput" - subtypes: false - name: "TranscoderInput" - signature: "(org.w3c.dom.Document)" - ext: "" - input: "Argument[0]" # document - kind: "sink-batik-TranscoderInput" - provenance: "manual" - metadata: - description: "SSRF by batik TranscoderInput" - - model: # TranscoderInput(java.io.Reader reader) - package: "org.apache.batik.transcoder" - type: "TranscoderInput" - subtypes: false - name: "TranscoderInput" - signature: "(Reader)" - ext: "" - input: "Argument[0]" # reader - kind: "sink-batik-TranscoderInput" - provenance: "manual" - metadata: - description: "SSRF by batik TranscoderInput" - - model: # TranscoderInput(java.lang.String uri) - package: "org.apache.batik.transcoder" - type: "TranscoderInput" - subtypes: false - name: "TranscoderInput" - signature: "(String)" - ext: "" - input: "Argument[0]" # uri - kind: "sink-batik-TranscoderInput" - provenance: "manual" - metadata: - description: "SSRF by batik TranscoderInput" - - model: # TranscoderInput(org.xml.sax.XMLReader xmlReader) - package: "org.apache.batik.transcoder" - type: "TranscoderInput" - subtypes: false - name: "TranscoderInput" - signature: "(org.xml.sax.XMLReader)" - ext: "" - input: "Argument[0]" # xmlReader - kind: "sink-batik-TranscoderInput" - provenance: "manual" - metadata: - description: "SSRF by batik TranscoderInput" - - model: # TranscoderInput(java.io.InputStream istream) - package: "org.apache.batik.transcoder" - type: "TranscoderInput" - subtypes: false - name: "TranscoderInput" - signature: "(InputStream)" - ext: "" - input: "Argument[0]" # inputStream - kind: "sink-batik-TranscoderInput" - provenance: "manual" - metadata: - description: "SSRF by batik TranscoderInput" diff --git a/crs/codeql/sinks-pack/models/.gitignore b/crs/codeql/sinks-pack/models/.gitignore deleted file mode 100644 index 1cda54be9..000000000 --- a/crs/codeql/sinks-pack/models/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.yml diff --git a/crs/codeql/sinks-pack/queries/.gitignore b/crs/codeql/sinks-pack/queries/.gitignore deleted file mode 100644 index f47ccd43d..000000000 --- a/crs/codeql/sinks-pack/queries/.gitignore +++ /dev/null @@ -1 +0,0 @@ -sinks.ql diff --git a/crs/codeql/templates/model.yml.j2 b/crs/codeql/templates/model.yml.j2 deleted file mode 100644 index 64e53f218..000000000 --- a/crs/codeql/templates/model.yml.j2 +++ /dev/null @@ -1,8 +0,0 @@ -extensions: - - addsTo: - pack: codeql/java-all - extensible: sinkModel - data: -{%- for sink in sinks %} - - ["{{ sink.package }}", "{{ sink.class_name }}", {{ sink.subtypes | lower }}, "{{ sink.name }}", "{{ sink.signature }}", "{{ sink.ext }}", "{{ sink.input_arg }}", "{{ sink.kind }}", "{{ sink.provenance }}"] -{%- endfor %} diff --git a/crs/codeql/templates/sinks.ql.j2 b/crs/codeql/templates/sinks.ql.j2 deleted file mode 100644 index 80ee80cc2..000000000 --- a/crs/codeql/templates/sinks.ql.j2 +++ /dev/null @@ -1,153 +0,0 @@ -/** - * @name API Sinks Analysis - * @description Identifies API sinks and checks for non-constant arguments or receivers - * @kind problem - * @problem.severity 
warning - * @id java/api-sinks-analysis - * @tags security - */ - -import java -import semmle.code.java.dataflow.DataFlow -import semmle.code.java.dataflow.ExternalFlow -import semmle.code.java.dataflow.internal.FlowSummaryImpl - -/** - * A sink for various types of API calls that can lead to security vulnerabilities. - * This includes XPath, command, environment, JNDI, LDAP, path injections, - * regex uses, request forgery, and SQL injection. - */ -class ApiSink extends DataFlow::Node { - private string sinkType; - - ApiSink() { -{%- for sink_type in sink_types %} - sinkType = "{{ sink_type }}" and sinkNode(this, sinkType){% if not loop.last %} or{% endif %} -{%- endfor %} - } - - /** - * Gets the sink type for this API sink. - */ - string getSinkType() { - result = sinkType - } - - /** - * Gets the fully qualified class name for this sink. - */ - string getClassFQDN() { - result = this.getEnclosingCallable().getDeclaringType().getQualifiedName() - } - - /** - * Gets the method name for this sink. - */ - string getMethodName() { - result = this.getEnclosingCallable().getName() - } - - /** - * Gets the method signature for better identification. - */ - string getMethodSignature() { - exists(Callable c | c = this.getEnclosingCallable() | - if exists(c.getSignature()) - then - exists(Type returnType | returnType = c.getReturnType() | - if returnType instanceof RefType - then result = returnType.(RefType).getSourceDeclaration().getQualifiedName() + " " + c.getSignature() - else result = returnType.toString() + " " + c.getSignature() - ) - else result = "UNKNOWN" - ) - } - - /** - * Gets the method descriptor for better identification. - */ - string getMethodDescriptor() { - result = this.getEnclosingCallable().getMethodDescriptor() - } - - /** - * Checks if this sink has at least one non-constant argument or a non-constant receiver. - */ - predicate hasNonConstantArgumentOrReceiver() { - exists(Expr expr | expr = this.asExpr() | - // Handle method calls - exists(MethodCall mc | mc = expr | - // Check if any argument is non-constant - exists(Expr arg | arg = mc.getAnArgument() | not arg.isCompileTimeConstant()) or - // Check if receiver/qualifier is non-constant (and exists) - exists(Expr qualifier | qualifier = mc.getQualifier() | not qualifier.isCompileTimeConstant()) - ) or - // Handle constructor calls - exists(ClassInstanceExpr cie | cie = expr | - // Check if any argument is non-constant - exists(Expr arg | arg = cie.getAnArgument() | not arg.isCompileTimeConstant()) - ) or - // Handle field access - exists(FieldAccess fa | fa = expr | - // Check if qualifier is non-constant - exists(Expr qualifier | qualifier = fa.getQualifier() | not qualifier.isCompileTimeConstant()) - ) or - // Handle array access - exists(ArrayAccess aa | aa = expr | - // Check if array or index is non-constant - not aa.getArray().isCompileTimeConstant() or - not aa.getIndexExpr().isCompileTimeConstant() - ) or - // Handle other expressions - if the expression itself is non-constant - (not expr.isCompileTimeConstant() and - not expr instanceof MethodCall and - not expr instanceof ClassInstanceExpr and - not expr instanceof FieldAccess and - not expr instanceof ArrayAccess) - ) - } - - /** - * Returns true if this sink has non-constant arguments or receiver, false otherwise. - */ - boolean hasNonConstantArgumentOrReceiverBoolean() { - if this.hasNonConstantArgumentOrReceiver() - then result = true - else result = false - } - - /** - * Gets the sink definition info of this sink. 
- */ - QlBuiltins::ExtensionId getMadId() { - exists( - string model, string namespace, string type, boolean subtypes, string name, string signature, string ext, - string originalInput, string provenance, QlBuiltins::ExtensionId madId - | - sinkNode(this, sinkType, model) and - sinkModel(namespace, type, subtypes, name, signature, ext, originalInput, sinkType, provenance, madId) and - model = "MaD:" + madId.toString() and - result = madId - ) - } - - string getModelInfo() { - exists(string modelInfo | interpretModelForTest(this.getMadId(), modelInfo) | - result = modelInfo - ) - } -} - - - -from ApiSink sink -select sink.getLocation(), - sink.getSinkType() as sink_type, - sink.hasNonConstantArgumentOrReceiverBoolean() as has_non_constant_args, - sink.getClassFQDN() as class_name, - sink.getMethodName() as method_name, - sink.getMethodSignature() as method_signature, - sink.getMethodDescriptor() as method_descriptor, - sink.getLocation().getFile().getBaseName() as file_path, - sink.getLocation().getStartLine() as line_number, - sink.getModelInfo() as model_info \ No newline at end of file diff --git a/crs/codeql/transform_results.py b/crs/codeql/transform_results.py new file mode 100644 index 000000000..bfc1615f0 --- /dev/null +++ b/crs/codeql/transform_results.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Transform CodeQL SARIF results from standard security queries to coordinate format. +""" + +import json +import sys +import os +from pathlib import Path + + +def deduplicate_sinkpoints(coordinates): + """Deduplicate sinkpoints, keeping only those with superset column ranges. + + If multiple sinkpoints have the same file, line, and CWE, only keep the one(s) + whose column range is not a proper subset of any other's column range. + """ + # Group by (file_name, line_num, cwe) + groups = {} + for coord in coordinates: + key = (coord['coord']['file_name'], coord['coord']['line_num'], coord['cwe']) + if key not in groups: + groups[key] = [] + groups[key].append(coord) + + # For each group, keep only entries with maximal column ranges + result = [] + for key, group in groups.items(): + if len(group) == 1: + result.append(group[0]) + else: + # Find entries whose column range is not a proper subset of any other + for entry1 in group: + start1 = entry1['coord']['start_column'] + end1 = entry1['coord']['end_column'] + + is_maximal = True + for entry2 in group: + if entry1 is entry2: + continue + start2 = entry2['coord']['start_column'] + end2 = entry2['coord']['end_column'] + + # Check if entry1's range is a proper subset of entry2's range + # entry1 is a proper subset of entry2 if: + # - entry2's range contains entry1's range + # - and they're not equal + if start2 <= start1 and end1 <= end2 and (start2 < start1 or end2 > end1): + is_maximal = False + break + + if is_maximal: + result.append(entry1) + + return result + + +def transform_codeql_results(input_files, output_file): + """Transform CodeQL SARIF results to coordinate format. 
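+
+    One output entry, with illustrative values only (the field names match
+    those emitted below):
+
+        {"coord": {"line_num": 42, "file_name": "src/Foo.java",
+                   "start_column": 9, "end_column": 30,
+                   "class_name": "com.example.Foo", "method_name": "bar"},
+         "id": "src/Foo.java:42:9:42:30", "cwe": "CWE-078",
+         "message": "Command injection sink",
+         "filtered_out_flow": false, "filtered_out_test": false}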
+ + Args: + input_files: Either a single SARIF file path or a list of SARIF file paths + output_file: Output JSON file path + """ + + # Read the SARIF file(s) + coordinates = [] + + # Handle both single file and list of files + if isinstance(input_files, str): + input_files = [input_files] + + for input_file in input_files: + if not os.path.exists(input_file): + print(f"Warning: Skipping non-existent file: {input_file}", file=sys.stderr) + continue + + with open(input_file, 'r') as f: + sarif_data = json.load(f) + + # Process each run in the SARIF file + for run in sarif_data.get('runs', []): + results = run.get('results', []) + + for result in results: + try: + # Extract CWE number and message text from result + rule_id = result.get('ruleId', '') + message_text = result.get('message', {}).get('text', '') + + # Check if the finding is filtered (message contains [FILTERED:...]) + filtered = "[FILTERED:" in message_text + + # Parse the class name and method name from the message. Format: + # "\\[class_name, method_name\\] message details... \\[FILTERED:...\\]" + class_name = '' + method_name = '' + if '[' in message_text and ']' in message_text: + first_bracket_close = message_text.index(']') + class_method_part = message_text[2:first_bracket_close-1] + if ',' in class_method_part: + class_name, method_name = [part.strip() for part in class_method_part.split(',', 1)] + message_text = message_text[first_bracket_close+1:].strip() + + # Extract CWE number from rule ID (e.g., "java/cwe-078-command-injection" -> "CWE-078") + cwe_number = '' + if 'cwe-' in rule_id.lower(): + parts = rule_id.lower().split('cwe-') + if len(parts) > 1: + # Extract the numeric part after 'cwe-' + cwe_part = parts[1].split('-')[0].split('/')[0] + cwe_number = f"CWE-{cwe_part}" + + # Get all locations from the result + locations = result.get('locations', []) + + for location in locations: + physical_location = location.get('physicalLocation', {}) + artifact_location = physical_location.get('artifactLocation', {}) + region = physical_location.get('region', {}) + + # Extract file path + file_path = artifact_location.get('uri', '') + if not file_path: + continue + + # Extract line and column information + start_line = region.get('startLine') + if start_line is None: + continue + + end_line = region.get('endLine', start_line) + start_column = region.get('startColumn', 1) + end_column = region.get('endColumn', start_column) + + # Create id: filename + start_line + start_column + end_line + end_column + entry_id = f"{file_path}:{start_line}:{start_column}:{end_line}:{end_column}" + + # Check if the file is a test file (contains "Test" in the path) + filtered_out_test = "Test" in file_path or "/test/" in file_path.lower() + + # Create one entry for each line from start_line to end_line + for line_num in range(start_line, end_line + 1): + coord_entry = { + "coord": { + "line_num": line_num, + "file_name": file_path, + "start_column": start_column, + "end_column": end_column, + "class_name": class_name, + "method_name": method_name + }, + "id": entry_id, + "cwe": cwe_number, + "message": message_text, + "filtered_out_flow": filtered, + "filtered_out_test": filtered_out_test + } + coordinates.append(coord_entry) + break + + except (KeyError, ValueError) as e: + print(f"Warning: Skipping malformed result: {e}", file=sys.stderr) + continue + + # Deduplicate sinkpoints by keeping only those with superset column ranges + unique_coordinates = deduplicate_sinkpoints(coordinates) + + # Write the transformed results + with 
open(output_file, 'w') as f:
+        json.dump(unique_coordinates, f, indent=2)
+
+    print(f"Transformed {len(unique_coordinates)} unique coordinate entries from {len(input_files)} SARIF file(s) to {output_file}")
+
+
+def main():
+    if len(sys.argv) < 3:
+        print("Usage: python3 transform_results.py <input.sarif> [<input2.sarif> ...] <output.json>")
+        print("Example: python3 transform_results.py out.sarif transformed_results.json")
+        print("Example: python3 transform_results.py file1.sarif file2.sarif ... output.json")
+        sys.exit(1)
+
+    # Last argument is output file, everything else is input
+    output_file = sys.argv[-1]
+    input_files = sys.argv[1:-1]
+
+    try:
+        transform_codeql_results(input_files, output_file)
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/crs/cpmetadata.py b/crs/cpmetadata.py
index cbd2cefa8..eaadfd7fe 100755
--- a/crs/cpmetadata.py
+++ b/crs/cpmetadata.py
@@ -7,7 +7,6 @@
     "built_path": "/out",
     "cp_full_src": "/src",
     "cp_name": "aixcc/jvm/mock-java",
-    "sink_target_conf": "/app/crs-cp-java/sink-targets.txt",
     "sinkpoint_path": "/crs-workdir/worker-0/metadata/aixcc/jvm/mock-java/sinkpoints.json",
     "harnesses": {
         "OssFuzz1": {
@@ -715,7 +714,6 @@ def _prepare_meta(self):
         self.repo_src_path = self.crs.cp.cp_src_path
         self.built_path = self.crs.cp.built_path
         self.ref_diff_path = self.crs.cp.diff_path
-        self.sink_target_conf = Path(os.environ["JAVA_CRS_SINK_TARGET_CONF"])
         self.custom_sink_conf = self.crs.sinkmanager.get_custom_sink_conf_path()
         self.sinkpoint_path = self.workdir / "sinkpoints.json"
         self._set_full_src_dir()
@@ -734,8 +732,7 @@ def _dump_meta(self):
             "ref_diff_path": (
                 str(self.ref_diff_path.resolve()) if self.ref_diff_path else ""
             ),
-            "sink_target_conf": str(self.sink_target_conf.resolve()),
-            "custom_sink_conf": str(self.custom_sink_conf.resolve()),
+            "custom_sink_conf_path": str(self.custom_sink_conf.resolve()),
             "sinkpoint_path": str(self.sinkpoint_path.resolve()),
             "built_path": str(self.built_path.resolve()),
             "cp_full_src": str(self.cp_full_src.resolve()),
@@ -872,7 +869,6 @@ def _install_meta(self):
         )
         os.environ["CP_SINKPOINTS_FILE"] = str(self.sinkpoint_path.resolve())
         os.environ["CP_METADATA_FILE"] = str(self.meta_path.resolve())
-        os.environ["CP_CUSTOM_SINK_CONF"] = str(self.custom_sink_conf.resolve())
         os.environ["DEEPGEN_TASK_REQ_DIR"] = str(
             self.crs.deepgen.get_task_req_dir().resolve()
         )
diff --git a/crs/crs-java.config b/crs/crs-java.config
index ace8081df..c1324233d 100644
--- a/crs/crs-java.config
+++ b/crs/crs-java.config
@@ -3,7 +3,7 @@
   "e2e_check": false,
   "sync_log": false,
   "verbose": false,
-  "ssmode": false,
+  "ssmode": true,
   "modules": {
     "cpuallocator": {
       "jazzer_cpu_ratio": 0,
@@ -33,10 +33,11 @@
       "len_control": 0,
       "mem_size": 4096,
       "jacoco_cov_dump_period": 299,
-      "deepgen_consumer": true
+      "deepgen_consumer": true,
+      "wait_for_llmpocgen": false
     },
     "atldirectedjazzer": {
-      "enabled": false,
+      "enabled": true,
       "keep_seed": true,
       "len_control": 0,
       "beepseed_search": true,
@@ -60,6 +61,8 @@
     "llmpocgen": {
       "enabled": true,
       "mode": "crs",
+      "models": ["claude-opus-4-20250514", "o3", "claude-sonnet-4-20250514", "gemini-2.5-pro", "gpt-4.1"],
+      "scan_sinks": false,
       "diff_max_len": 17000,
       "worker_num": 6
     },
@@ -69,13 +72,14 @@
       "verbose": false
     },
     "staticanalysis": {
-      "enabled": false,
+      "enabled": true,
       "static_ana_phases": ["cha-0"],
       "schedule_interval": 300,
       "schedule_target_num": 5
     },
    "sinkmanager": {
-      "enabled": true
+      "enabled": true,
+      "allowed_sink_contributors": ["sinkdetection"]
     },
     "sariflistener": {
       "enabled":
false @@ -95,16 +99,18 @@ "enabled": true, "exp_time": 300, "monitor_interval": 3, - "gen_models": "gpt-5:1", - "x_models": "o4-mini:1", + "gen_model": "gemini-2.5-pro", "schedule_count": 2 }, "dictgen": { "enabled": false, "gen_models": "gpt-4o:1" }, - "codeql": { - "enabled": false + "sinkdetection": { + "enabled": true, + "gen_model": "gpt-5-nano", + "max_iterations": 15, + "max_workers": 10 }, "concolic": { "enabled": false, diff --git a/crs/expkit/expkit/agent.py b/crs/expkit/expkit/agent.py new file mode 100644 index 000000000..1a3db8997 --- /dev/null +++ b/crs/expkit/expkit/agent.py @@ -0,0 +1,741 @@ +#!/usr/bin/env python3 + +import hashlib +import fnmatch +import json +import logging +import os +import tempfile + +from pathlib import Path + +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain_litellm import ChatLiteLLM +from langchain.schema import LLMResult, AgentAction, AgentFinish +from langchain.tools import BaseTool, tool +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import Runnable +from litellm import cost_per_token, RateLimitError +from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type +from typing import Any, Dict, Type, Union +from pydantic import BaseModel, Field, field_validator + +from .beepobjs import BeepSeed +from .cpmeta import CPMetadata +from .fuzzer.jazzer import JazzerFuzzer +from .utils import CRS_ERR_LOG, CRS_WARN_LOG, get_with_model_provider + +CRS_ERR = CRS_ERR_LOG("agent") +CRS_WARN = CRS_WARN_LOG("agent") +logger = logging.getLogger(__name__) +logging.getLogger("LiteLLM").setLevel(logging.INFO) + +class HexStringInput(BaseModel): + """Hex-encoded string input.""" + data: str = Field(description="The hex encoded input (e.g., '48656c6c6f')") + + +class BytesInput(BaseModel): + """Direct bytes array input.""" + data: bytes = Field(description="The raw bytes input (not hex encoded)") + + +class PythonScriptInput(BaseModel): + """Python script that generates input.""" + code: str = Field(description="Python code that returns bytes when executed, can only contain ASCII-printable characters; last instruction or 'result' variable is returned") + + @field_validator('code') + def validate_code(cls, v): + # Ensure it's all ASCII and no null or control characters + if not all(32 <= ord(c) <= 126 or c in '\n\r\t' for c in v): + raise ValueError("Code must be ASCII and cannot contain null or control characters") + return v + +class FlexibleInput(BaseModel): + """Flexible input schema supporting multiple input types using Union.""" + input_data: Union[HexStringInput, BytesInput, PythonScriptInput] = Field( + description="Input data in one of the supported formats" + ) + + def to_bytes(self) -> bytes: + """Convert the input to bytes regardless of the input type.""" + if isinstance(self.input_data, HexStringInput): + try: + return bytes.fromhex(self.input_data.data) + except ValueError as e: + raise ValueError(f"Invalid hex input: {e}") + elif isinstance(self.input_data, BytesInput): + return self.input_data.data + elif isinstance(self.input_data, PythonScriptInput): + try: + # Execute the Python script in a restricted environment + local_vars = {} + global_vars = { + "__builtins__": { + "bytearray": bytearray, + "bytes": bytes, + "int": int, + "str": str, + "list": list, + "tuple": tuple, + "dict": dict, + "range": range, + "len": len, + "sum": sum, + "min": min, + "max": max, + "abs": 
abs, + "pow": pow, + "round": round, + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "True": True, + "False": False, + "type": type, + "None": None, + "__import__": __import__, + } + } + + exec(self.input_data.code, global_vars, local_vars) + + # Look for a result variable or the last expression + if 'result' in local_vars: + result = local_vars['result'] + else: + result = list(local_vars.values())[-1] if local_vars else None + + if isinstance(result, bytes): + return result + elif isinstance(result, str): + return bytes.fromhex(result) + else: + try: + # Try to convert anything else to bytes + return bytes(result) + except Exception: + raise ValueError(f"Script must return bytes, str (hex), or something convertible to bytes, got {type(result)}") + except Exception as e: + raise ValueError(f"Error executing Python script: {e}") + else: + raise ValueError(f"Unsupported input type: {self.input_data.type}") + + +class JdbToolInput(BaseModel): + """Input schema for JDB tool.""" + input_data: Union[HexStringInput, BytesInput, PythonScriptInput] = Field( + description="Input data in one of the supported formats" + ) + break_class: str = Field(description="The class where to set the breakpoint") + break_line: int = Field(description="The line number where to set the breakpoint") + commands: str = Field(description="A list of jdb commands to execute after hitting the breakpoint") + + +@tool(args_schema=JdbToolInput) +def jdb_tool(input_data: Union[HexStringInput, BytesInput, PythonScriptInput], + break_class: str, break_line: int, commands: str) -> str: + """A tool to interact with the Java Debugger (jdb). + + Args: + input_data: The input data in one of the supported formats + break_class: The class where to set the breakpoint + break_line: The line number where to set the breakpoint + commands: A list of jdb commands to execute after hitting the breakpoint + + Returns: + The output from the jdb session + """ + try: + flexible_input = FlexibleInput(input_data=input_data) + input_bytes = flexible_input.to_bytes() + hex_input = input_bytes.hex() + # TODO(fab1ano): implement actual JDB functionality using hex_input + return f"jdb tool not implemented yet (would use hex input: {hex_input})" + except Exception as e: + return f"Error processing input: {e}" + + +class GenericPoVVerifier: + """Generic PoV verifier using Jazzer.""" + jazzer_base: JazzerFuzzer = None + counter: int = 0 + first_solved: int = None + beepseed: BeepSeed = None + last_input: bytes = None + work_dir: Path = None + + def __init__(self, jazzer: JazzerFuzzer, beepseed: BeepSeed, work_dir: Path): + super().__init__() + self.jazzer_base = jazzer + self.beepseed = beepseed + self.work_dir = work_dir + + def check_crashes(self, input_data: Union[HexStringInput, BytesInput, PythonScriptInput]) -> str: + """A tool to check if a PoV exploits the vulnerability. 
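+
+        A hypothetical invocation, for illustration only:
+
+            check_crashes(HexStringInput(data="deadbeef"))
+
+        runs a single Jazzer verification pass with the four bytes
+        b"\xde\xad\xbe\xef" and reports whether a matching crash was seen.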
+ + Args: + input_data: The input data in one of the supported formats + + Returns: + The result of the verification + """ + self.counter += 1 + logger.info(f"Verifying input (attempt {self.counter})") + + try: + flexible_input = FlexibleInput(input_data=input_data) + input_bytes = flexible_input.to_bytes() + except Exception as e: + logger.error(f"Error processing input (attempt {self.counter}): {e}") + return f"Error processing input (attempt {self.counter}): {e}" + + self.last_input = input_bytes + + crashes = self._get_crash_for_input(input_bytes) + logger.info(f"All crashes detected: {crashes}") + crash_found = any([self._check_crash(crash) for crash in crashes]) + + if crash_found: + logger.info(f"Valid crash found in verifier (attempt {self.counter})!") + if self.first_solved is None: + self.first_solved = self.counter + return f"Exploit successful, crash found" + else: + logger.info(f"No valid crash found in verifier (attempt {self.counter})") + return "No crashes found, exploit unsuccessful" # TODO(fab1ano): add info whether sinkpoint was reached + + def _check_crash(self, crash: Any) -> bool: + """Check if the given crash matches the expected vulnerability. + + Args: + crash: The crash data to check + + Returns: + True if the crash matches the expected vulnerability, False otherwise + """ + class_name = self.beepseed.coord.class_name.replace("/", ".") + method_name = self.beepseed.coord.method_name + line_no = self.beepseed.coord.line_num + file_name = self.beepseed.coord.file_name + signature = f"{class_name}.{method_name}({file_name}:{line_no})" + logger.warning(f"{CRS_WARN} Checking crash for signature: {signature}") + + stack_trace = crash[3] + for stack_frame in stack_trace: + if signature in stack_frame: + return True + + return False + + def _get_crash_for_input(self, input_bytes: bytes) -> list[Any]: + """Run jazzer with the given input and check for crashes. 
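+
+        Inputs are content-addressed: for example (illustrative), an input
+        whose sha256 is "ab12..." runs as input_id "verify-ab12..." and
+        run_id "id-ab12...", so re-verifying the same PoV reuses stable names.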
+ + Args: + input_bytes: The input bytes to be used as PoV + + Returns: + A list of crashes found, empty if none + """ + + # Define the run_id + input_hash = hashlib.sha256(input_bytes).hexdigest() + input_id = f"verify-{input_hash}" + run_id = f"id-{input_hash}" + + self.jazzer_base.add_corpus_file(input_bytes, input_id) + + with tempfile.TemporaryDirectory() as work_dir: + logger.info(f"Running verification in {work_dir}") + jazzer = self.jazzer_base.clone(work_dir=Path(work_dir)) + input_file = jazzer.add_corpus_file(input_bytes, "poc") + result_json = jazzer.fuzz(run_id, 60, verify_only=True) + logger.info("Finished jazzer run, checking results") + + fuzz_log_file = jazzer.fuzz_log + if fuzz_log_file.exists(): + # Save the fuzz log for debugging + saved_log = self.work_dir / f"verify-fuzz-{self.counter}.log" + with open(fuzz_log_file, "r") as src, open(saved_log, "w") as dst: + dst.write(src.read()) + logger.info(f"Saved fuzz log to {saved_log}") + + # Check for "Executed " in the log + with open(fuzz_log_file, "r") as f: + log_content = f.read() + if f"Executed {input_file}" not in log_content: + logger.warning(f"{CRS_ERR} Input file {input_file} was not executed during verification!") + if "ERROR: libFuzzer: timeout after " in log_content: + logger.warning(f"{CRS_ERR} Verification run timed out") + return [] + else: + raise RuntimeError(f"Input file {input_file} was not executed during verification!") + else: + logger.warning(f"{CRS_ERR} Fuzz log file {fuzz_log_file} does not exist!") + raise RuntimeError(f"Fuzz log file {fuzz_log_file} does not exist!") + + if result_json.exists(): + with open(result_json, "r") as f: + result = json.load(f) + # Check if there's an entry in result["fuzz_data"]["log_dedup_crash_over_time"] + # TODO(fab1ano): improve crash detection logic + if "fuzz_data" in result and "log_dedup_crash_over_time" in result["fuzz_data"]: + crashes = result["fuzz_data"]["log_dedup_crash_over_time"] + if crashes: + return crashes + + logger.info("No crashes found in verifier.") + return [] + + +class VerificationTool(BaseTool): + """A tool to verify if a PoV exploits the vulnerability.""" + name: str = "verify_exploit" + description: str = ("A tool to check if a PoV exploits the vulnerability. 
" + "It takes as input a PoV " + "and returns whether it successfully triggers the vulnerability " + "by running it against the target Java program using Jazzer.") + verifier: GenericPoVVerifier = None + + def __init__(self, verifier: GenericPoVVerifier): + super().__init__() + self.verifier = verifier + + +class HexStringVerificationTool(VerificationTool): + """PoV checker tool specialized for hex string input.""" + name: str = "verify_exploit_hex" + args_schema: Type[BaseModel] = HexStringInput + + def __init__(self, verifier: GenericPoVVerifier): + super().__init__(verifier) + + def _run(self, data: str) -> str: + return self.verifier.check_crashes(HexStringInput(data=data)) + + +class BytesVerificationTool(VerificationTool): + """PoV checker tool specialized for bytes array input.""" + name: str = "verify_exploit_bytes" + args_schema: Type[BaseModel] = BytesInput + + def __init__(self, verifier: GenericPoVVerifier): + super().__init__(verifier) + + def _run(self, data: bytes) -> str: + return self.verifier.check_crashes(BytesInput(data=data)) + + +class ScriptVerificationTool(VerificationTool): + """PoV checker tool specialized for Python script input.""" + name: str = "verify_exploit_script" + args_schema: Type[BaseModel] = PythonScriptInput + + def __init__(self, verifier: GenericPoVVerifier): + super().__init__(verifier) + + def _run(self, code: str) -> str: + return self.verifier.check_crashes(PythonScriptInput(code=code)) + + +class FilePathInput(BaseModel): + """Input schema for file path.""" + file_path: str = Field(description="Path to the file to read (absolute or relative)") + + +class UnifiedReadFileTool(BaseTool): + """Unified read file tool that checks both root directories.""" + name: str = "read_file" + description: str = "Read a file from either the OSS-Fuzz project directory or the challenge project source code. Provide either an absolute path or a relative path (will check both roots)." 
+ args_schema: Type[BaseModel] = FilePathInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, file_path: str) -> str: + """Read file from one of the configured roots.""" + + path = Path(file_path) + + # If absolute path, check if it's within one of the roots + if path.is_absolute(): + if self.ossfuzz_root and path.is_relative_to(self.ossfuzz_root): + if path.exists() and path.is_file(): + return path.read_text() + else: + return f"Error: File not found: {file_path}" + elif self.source_root and path.is_relative_to(self.source_root): + if path.exists() and path.is_file(): + return path.read_text() + else: + return f"Error: File not found: {file_path}" + else: + return f"Error: Path {file_path} is not within allowed directories ({self.ossfuzz_root}, {self.source_root})" + + # If relative path, try both roots + else: + if self.ossfuzz_root: + ossfuzz_path = Path(self.ossfuzz_root) / file_path + if ossfuzz_path.exists() and ossfuzz_path.is_file(): + return ossfuzz_path.read_text() + + if self.source_root: + source_path = Path(self.source_root) / file_path + if source_path.exists() and source_path.is_file(): + return source_path.read_text() + + return f"Error: File not found in any configured directory: {file_path}" + + +class DirectoryPathInput(BaseModel): + """Input schema for directory path.""" + dir_path: str = Field(description="Path to the directory to list (absolute or relative)") + + +class UnifiedListDirectoryTool(BaseTool): + """Unified list directory tool that checks both root directories.""" + name: str = "list_directory" + description: str = "List contents of a directory from either the OSS-Fuzz project directory or the challenge project source code. Provide either an absolute path or a relative path (will check both roots)." 
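+    # Illustrative resolution order for a relative path (hypothetical roots):
+    #   list_directory("src/main") -> <ossfuzz_root>/src/main first, then <source_root>/src/main
+    # Absolute paths must fall inside one of the two roots.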
+ args_schema: Type[BaseModel] = DirectoryPathInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, dir_path: str) -> str: + """List directory contents from one of the configured roots.""" + + path = Path(dir_path) + + # If absolute path, check if it's within one of the roots + if path.is_absolute(): + if self.ossfuzz_root and path.is_relative_to(self.ossfuzz_root): + if path.exists() and path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in path.iterdir()] + return "\n".join(sorted(items)) + else: + return f"Error: Directory not found: {dir_path}" + elif self.source_root and path.is_relative_to(self.source_root): + if path.exists() and path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in path.iterdir()] + return "\n".join(sorted(items)) + else: + return f"Error: Directory not found: {dir_path}" + else: + return f"Error: Path {dir_path} is not within allowed directories ({self.ossfuzz_root}, {self.source_root})" + + # If relative path, try both roots + else: + if self.ossfuzz_root: + ossfuzz_path = Path(self.ossfuzz_root) / dir_path + if ossfuzz_path.exists() and ossfuzz_path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in ossfuzz_path.iterdir()] + return "\n".join(sorted(items)) + + if self.source_root: + source_path = Path(self.source_root) / dir_path + if source_path.exists() and source_path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in source_path.iterdir()] + return "\n".join(sorted(items)) + + return f"Error: Directory not found in any configured directory: {dir_path}" + + +class SearchPatternInput(BaseModel): + """Input schema for search pattern.""" + pattern: str = Field(description="File name pattern to search for (e.g., '*.java', 'Main.java')") + + +class UnifiedFileSearchTool(BaseTool): + """Unified file search tool that searches in both root directories.""" + name: str = "search_files" + description: str = "Search for files by name pattern in both the OSS-Fuzz project directory and the challenge project source code. Provide a search pattern (uses fnmatch; supports wildcards like *.java)." 
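+    # fnmatch is applied to both the full path and the basename, e.g.
+    # (hypothetical) "Fuzz*.java" matches /src/repo/harness/FuzzHarness.java
+    # via its basename even though the full path does not match the pattern.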
+ args_schema: Type[BaseModel] = SearchPatternInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, pattern: str) -> str: + """Search for files matching pattern in configured roots.""" + + results = [] + + def search_in_directory(root: Path, pattern: str): + matches = [] + for path in root.rglob('*'): + if path.is_file() and (fnmatch.fnmatch(str(path), pattern) or + fnmatch.fnmatch(path.name, pattern)): + matches.append(str(path)) + return matches + + if self.ossfuzz_root: + ossfuzz_results = search_in_directory(Path(self.ossfuzz_root), pattern) + results.extend(ossfuzz_results) + + if self.source_root: + source_results = search_in_directory(Path(self.source_root), pattern) + results.extend(source_results) + + if results: + return "\n".join(sorted(results)) + else: + return f"No files found matching pattern: {pattern}" + + +class LiteLLMCostTracker(BaseCallbackHandler): + """Custom callback handler to track exact costs from LiteLLM server.""" + model: str + + def __init__(self, model: str): + super().__init__() + self.model = model + self.total_cost = 0.0 + self.total_tokens = 0 + self.total_prompt_tokens = 0 + self.total_completion_tokens = 0 + self.llm_calls = [] + + def on_llm_end( + self, + response: LLMResult, + **kwargs: Any, + ) -> None: + """Calculate cost when LLM call completes.""" + + input_tokens = 0 + output_tokens = 0 + + for gens in response.generations: + for gen in gens: + if gen.message.usage_metadata: + meta = gen.message.usage_metadata + input_tokens += meta.get('input_tokens', 0) + output_tokens += meta.get('output_tokens', 0) + + try: + prompt_cost, completion_cost = cost_per_token(self.model, input_tokens, output_tokens) + except Exception as e: + logger.warning(f"Error calculating cost for model {self.model}, trying with provider: {e}") + model_with_provider = get_with_model_provider(self.model) + try: + prompt_cost, completion_cost = cost_per_token(model_with_provider, input_tokens, output_tokens) + except Exception as e2: + logger.error(f"Failed to calculate cost with provider for model {model_with_provider}: {e2}") + prompt_cost, completion_cost = 0.0, 0.0 # Fallback to zero cost if we can't calculate + + # Update totals + self.total_cost += prompt_cost + completion_cost + self.total_tokens += input_tokens + output_tokens + self.total_prompt_tokens += input_tokens + self.total_completion_tokens += output_tokens + + # Store call details + call_info = { + "model": self.model, + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cost": prompt_cost + completion_cost, + } + self.llm_calls.append(call_info) + + def get_summary(self) -> Dict[str, Any]: + """Get cost tracking summary.""" + return { + "total_cost": round(self.total_cost, 6), + "total_tokens": self.total_tokens, + "total_prompt_tokens": self.total_prompt_tokens, + "total_completion_tokens": self.total_completion_tokens, + "num_llm_calls": len(self.llm_calls), + "cost_breakdown": self.llm_calls + } + + def reset(self): + """Reset the cost tracker.""" + self.total_cost = 0.0 + self.total_tokens = 0 + self.total_prompt_tokens = 0 + self.total_completion_tokens = 0 + self.llm_calls = [] + + +class ExpkitExecutionContext: + """Custom AgentExecutor for exploit generation and verification.""" + llm: ChatLiteLLM + cost_tracker: LiteLLMCostTracker + tools: 
list[BaseTool] + agent: Runnable + agent_executor: AgentExecutor + verifier: GenericPoVVerifier + work_dir: Path = None + prompt: ChatPromptTemplate = None + + def __init__( + self, + model: str, + temperature: float, + jazzer: JazzerFuzzer, + beepseed: BeepSeed, + work_dir: Path, + output_formats: list[str] = None, + cp_meta: CPMetadata = None, + max_iterations: int = 30, + force_submission_threshold: float = 0.8 + ): + api_key = os.environ.get("LITELLM_KEY") + base_url = os.environ.get("AIXCC_LITELLM_HOSTNAME") + + self.work_dir = work_dir + + self.cost_tracker = LiteLLMCostTracker(model) + + # Initialize the language model + self.llm = ChatLiteLLM( + model=f"litellm_proxy/{model}", + api_key=api_key, + api_base=base_url, + temperature=temperature, + callbacks=[self.cost_tracker], + request_timeout=240 + ) + self.llm.streaming = False # Disable streaming for tool calling + + # Initialize the verifier + self.verifier = GenericPoVVerifier(jazzer, beepseed, work_dir) + + # Default to all formats if not specified + if output_formats is None: + output_formats = ["hexstring", "bytes", "script"] + + # Normalize format names to lowercase + output_formats = [fmt.lower() for fmt in output_formats] + + # Create the tools list based on selected formats + verification_tools = [] + + if "hexstring" in output_formats: + verification_tools.append(HexStringVerificationTool(self.verifier)) + + if "bytes" in output_formats: + verification_tools.append(BytesVerificationTool(self.verifier)) + + if "script" in output_formats: + verification_tools.append(ScriptVerificationTool(self.verifier)) + + # Add code exploration tools if cp_meta is provided + code_exploration_tools = [] + if cp_meta is not None: + # Get both root paths + ossfuzz_root = str(cp_meta.get_proj_path()) if cp_meta.get_proj_path() else None + source_root = str(cp_meta.get_repo_src_path()) if cp_meta.get_repo_src_path() else None + + # Create unified tools that work with both roots + if ossfuzz_root or source_root: + read_tool = UnifiedReadFileTool( + ossfuzz_root=ossfuzz_root, + source_root=source_root + ) + code_exploration_tools.append(read_tool) + + list_tool = UnifiedListDirectoryTool( + ossfuzz_root=ossfuzz_root, + source_root=source_root + ) + code_exploration_tools.append(list_tool) + + search_tool = UnifiedFileSearchTool( + ossfuzz_root=ossfuzz_root, + source_root=source_root + ) + code_exploration_tools.append(search_tool) + + logger.info(f"Added unified file access tools with roots: ossfuzz={ossfuzz_root}, source={source_root}") + else: + logger.warning(f"{CRS_WARN} Could not determine source roots from cp_meta, code exploration tools not added") + + self.tools = verification_tools + code_exploration_tools + + # Create the prompt template + self.prompt = ChatPromptTemplate.from_messages([ + ("system", "You are an expert Java security researcher and exploit developer. " + "The user provides you with details about a vulnerability in a Java program " + "and an input that reaches the sinkpoint of the vulnerability. " + "It is your task to use the given information and the tools to develop a working exploit. " + "Be precise and technical in your analyses and when creating the PoV. " + "Make sure to use the verification tool to check if your PoV works. 
"), + ("user", "{input}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ]) + + # Create the agent + self.agent = create_tool_calling_agent(self.llm, self.tools, self.prompt) + + # Create the agent executor + self.agent_executor = AgentExecutor( + agent=self.agent, + tools=self.tools, + verbose=False, # Set to True to see the agent's thought process + handle_parsing_errors=True, + return_intermediate_steps=True, + max_iterations=max_iterations, # Limit iterations to prevent infinite loops + ) + + @retry( + wait=wait_exponential(multiplier=1, min=4, max=60), + stop=stop_after_attempt(10), + retry=retry_if_exception_type(RateLimitError) + ) + def run(self, query: str) -> Dict: + return self.agent_executor.invoke({"input": query}) + + def solution(self) -> str | None: + """Get the last solution in hex format, or None if not set.""" + return self.verifier.last_input.hex().lower() if self.verifier.last_input else None + + def tool_call_solved(self) -> int: + return self.verifier.first_solved + + def get_cost_summary(self) -> str: + summary = self.cost_tracker.get_summary() + return (f"Total Cost: ${summary['total_cost']}, " + f"Total Tokens: {summary['total_tokens']} " + f"(Prompt: {summary['total_prompt_tokens']}, " + f"Completion: {summary['total_completion_tokens']}), " + f"LLM Calls: {summary['num_llm_calls']}") + + def get_cost_summary_verbose(self) -> str: + summary = self.cost_tracker.get_summary() + lines = [ + "\n=== Cost Summary ===", + f"Total Cost: ${summary['total_cost']}", + f"Total Tokens: {summary['total_tokens']}", + f" - Prompt Tokens: {summary['total_prompt_tokens']}", + f" - Completion Tokens: {summary['total_completion_tokens']}", + f"Number of LLM Calls: {summary['num_llm_calls']}", + "Cost Breakdown per Call:" + ] + for call in summary['cost_breakdown']: + lines.append( + f"Prompt Tokens: {call['prompt_tokens']}, " + f"Completion Tokens: {call['completion_tokens']}, " + f"Total Tokens: {call['total_tokens']}, " + f"Cost: ${call['cost']}" + ) + lines.append("====================\n") + return "\n".join(lines) diff --git a/crs/expkit/expkit/cpmeta.py b/crs/expkit/expkit/cpmeta.py index 3787e1393..c0c75de9d 100755 --- a/crs/expkit/expkit/cpmeta.py +++ b/crs/expkit/expkit/cpmeta.py @@ -89,3 +89,21 @@ def resolve_frame_to_file_path(self, frame) -> str | None: if not frame or "class_name" not in frame or "file_name" not in frame: return None return self.resolve_file_path(frame["class_name"], frame["file_name"]) + + def get_proj_path(self) -> Path | None: + """Get the OSS-Fuzz project directory path from metadata. + + Returns: + Path to the OSS-Fuzz project directory (harnesses, build scripts, etc.) + """ + proj_path = self.metadata.get("proj_path", None) + return Path(proj_path) if proj_path else None + + def get_repo_src_path(self) -> Path | None: + """Get the challenge project source code directory path from metadata. 
+ + Returns: + Path to the challenge project source code directory + """ + repo_src_path = self.metadata.get("repo_src_path", None) + return Path(repo_src_path) if repo_src_path else None diff --git a/crs/expkit/expkit/exploit.py b/crs/expkit/expkit/exploit.py index 6aba8e0fe..6e96f2e52 100755 --- a/crs/expkit/expkit/exploit.py +++ b/crs/expkit/expkit/exploit.py @@ -8,7 +8,6 @@ from .beepobjs import BeepSeed from .cpmeta import CPMetadata -from .llm import LLMClient from .redis import RedisCacheClient from .sinkpoint_beep import SinkpointExpTool from .utils import CRS_ERR_LOG, CRS_WARN_LOG @@ -31,75 +30,43 @@ print(f"{CRS_ERR} Failed to install OpenTelemetry logger: {e}.") -def parse_model_weights(models_str: str) -> list: - """Parse a string of model:weight pairs into a weighted list. - - Format: "model1:weight1,model2:weight2,..." - Example: "gpt-4o:10,o1-preview:20,none:10" - - Returns a list with models repeated according to their weights. - """ - if not models_str: - return [] - - result = [] - try: - for pair in models_str.split(","): - if ":" not in pair: - logger.warning( - f"{CRS_ERR} Invalid model:weight pair: {pair}, expected format 'model:weight'" - ) - continue - - model, weight_str = pair.split(":", 1) - try: - weight = int(weight_str) - if weight <= 0: - logger.warning( - f"{CRS_ERR} Invalid weight {weight} for model {model}, must be positive" - ) - continue - result.extend([model.strip()] * weight) - except ValueError: - logger.warning( - f"{CRS_ERR} Invalid weight value: {weight_str}, must be an integer" - ) - except Exception as e: - logger.error(f"{CRS_ERR} parsing model weights: {e}") - - return result - - def run( beepseed_path: str, output_path: str, metadata_path: str, exp_time: int = 300, verbose: bool = False, + skip_redis_cache: bool = False, workdir: str = None, - gen_models_str: str = None, - x_models_str: str = None, + gen_model: str = None, + output_formats: list[str] = None, + max_iterations: int = 30, + trim_context: bool = True, ): """Run the appropriate exploitation based on BeepSeed type""" # TODO: extend beepseed types and exploit tools + result = {"status": False} try: workdir_path = Path(workdir) if workdir else None + logger.info(f"Generation model: {gen_model}") - gen_models = parse_model_weights(gen_models_str) - logger.info(f"Generation models distribution: {gen_models}") + if output_formats is None: + output_formats = ["hexstring", "bytes", "script"] + logger.info(f"Output formats enabled: {output_formats}") - x_models = parse_model_weights(x_models_str) - logger.info(f"Extraction models distribution: {x_models}") + cp_meta = CPMetadata(metadata_path) exp_tool = SinkpointExpTool( - llm_client=LLMClient(verbose=verbose), redis_client=RedisCacheClient(), beepseed=BeepSeed.frm_beep_file(beepseed_path), exp_time=exp_time, - cp_meta=CPMetadata(metadata_path), + cp_meta=cp_meta, workdir=workdir_path, - gen_models=gen_models, - x_models=x_models, + gen_model=gen_model, + skip_redis_cache=skip_redis_cache, + output_formats=output_formats, + max_iterations=max_iterations, + trim_context=trim_context, ) result = exp_tool.exploit() @@ -107,7 +74,7 @@ def run( except Exception as e: err_str = f"Exception: {e}\n{traceback.format_exc()}" logger.error(f"{CRS_ERR} Exploitation failed with {err_str}") - result = {"status": False, "error": err_str} + result["error"] = err_str finally: try: @@ -137,18 +104,32 @@ def main(): help="Time limit for exploitation in seconds", ) parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + 
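+    # Programmatic entry (hypothetical paths), equivalent to what this CLI
+    # ultimately does once the flags below are parsed:
+    #   run("seed.beep", "out.json", "cpmeta.json", exp_time=300, verbose=False,
+    #       skip_redis_cache=True, workdir="/tmp/exp", gen_model="gpt-5",
+    #       output_formats=["hexstring", "bytes", "script"], max_iterations=30,
+    #       trim_context=True)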
parser.add_argument("--no-cache", action="store_true", help="Disable Redis cache lookup") parser.add_argument( "--workdir", default=None, help="Working directory for exploitation artifacts" ) parser.add_argument( - "--gen-models", - default="o1-preview:1,claude-3-7-sonnet-20250219:1", - help="Comma-separated list of generation models with weights. Format: 'model1:weight1,model2:weight2,...'. Example: 'o1-preview:10,claude-3-7-sonnet-20250219:20,none:5'", + "--gen-model", + default="gpt-5", + help="LLM model used by the exploitation agent. Default: gpt-5", + ) + parser.add_argument( + "--output-formats", + nargs="+", + choices=["hexstring", "bytes", "script", "all"], + default=["all"], + help="Output formats to enable for agent tools. Choices: hexstring, bytes, script, all. Default: all", ) parser.add_argument( - "--x-models", - default="gpt-4o:1,o3-mini:1", - help="Comma-separated list of extraction models with weights. Format: 'model1:weight1,model2:weight2,...'. Example: 'gpt-4o:10,o3-mini:20,none:5'", + "--max-iterations", + type=int, + default=30, + help="Maximum number of agent iterations before forced PoV submission. Default: 30", + ) + parser.add_argument( + "--no-trim-context", + action="store_true", + help="Disable context trimming - include full files in prompts instead of snippets. Default: False (trimming enabled)", ) args = parser.parse_args() @@ -156,15 +137,23 @@ def main(): if args.verbose: logging.getLogger().setLevel(logging.DEBUG) + # Process output formats + output_formats = args.output_formats + if "all" in output_formats: + output_formats = ["hexstring", "bytes", "script"] + run( args.beepseed, args.output, args.metadata, args.exp_time, args.verbose, + args.no_cache, args.workdir, - args.gen_models, - args.x_models, + args.gen_model, + output_formats, + args.max_iterations, + trim_context=not args.no_trim_context, ) diff --git a/crs/expkit/expkit/fuzzer/jazzer.py b/crs/expkit/expkit/fuzzer/jazzer.py index fc6ece59f..34248a803 100755 --- a/crs/expkit/expkit/fuzzer/jazzer.py +++ b/crs/expkit/expkit/fuzzer/jazzer.py @@ -118,7 +118,7 @@ def _write_command_script(self, cwd: str): command_sh.chmod(0o755) return command_sh - def _init_result_json(self, fuzz_id: str, fuzz_time: int, mem_size: int): + def _init_result_json(self, fuzz_id: str, fuzz_time: int, mem_size: int, verify_only: bool = False): if not self.result_json.parent.exists(): self.result_json.parent.mkdir(parents=True, exist_ok=True) @@ -129,7 +129,9 @@ def _init_result_json(self, fuzz_id: str, fuzz_time: int, mem_size: int): self.env["FUZZ_CUSTOM_ARGS"] = " ".join(self.custom_args) self.env["FUZZ_TARGET_HARNESS"] = self.target_harness if self.custom_sink_conf_path is not None: - self.env["FUZZ_CUSTOM_SINK_CONF"] = self.custom_sink_conf_path + self.env["FUZZ_CUSTOM_SINK_CONF"] = str(self.custom_sink_conf_path) + if verify_only: + self.env["FUZZ_VERIFY_ONLY"] = "1" init_data = { "cp": self.cp_name, @@ -159,12 +161,12 @@ def _init_result_json(self, fuzz_id: str, fuzz_time: int, mem_size: int): with open(self.result_json, "w") as f: json.dump(init_data, f, indent=2) - def fuzz(self, fuzz_id: str, fuzz_time: int, mem_size: int = 4096) -> Path: + def fuzz(self, fuzz_id: str, fuzz_time: int, mem_size: int = 4096, verify_only: bool = False): try: self._write_dict_file() logger.info(f"Fuzz dict file has {len(self.dict_values)} entries") - self._init_result_json(fuzz_id, fuzz_time, mem_size) + self._init_result_json(fuzz_id, fuzz_time, mem_size, verify_only=verify_only) cwd = f"/tmp-{fuzz_id}" command_sh = 
self._write_command_script(cwd) @@ -205,3 +207,16 @@ def fuzz(self, fuzz_id: str, fuzz_time: int, mem_size: int = 4096) -> Path: err_str = f"{CRS_ERR} Fuzzing failed: {str(e)}" logger.error(f"{err_str} with traceback:\n{traceback.format_exc()}") raise RuntimeError(err_str) + + def clone(self, work_dir: Path): + return JazzerFuzzer( + jazzer_dir=self.jazzer_dir, + work_dir=work_dir, + cp_name=self.cp_name, + target_harness=self.target_harness, + fuzz_target=self.fuzz_target, + target_classpath=self.target_classpath, + custom_sink_conf_path=self.custom_sink_conf_path, + cpu_id=self.cpu_id, + custom_args=list(self.custom_args), + ) diff --git a/crs/expkit/expkit/fuzzer/scripts/jazzer_postprocessing.py b/crs/expkit/expkit/fuzzer/scripts/jazzer_postprocessing.py index 3ae73b6c4..bbbb19f9a 100755 --- a/crs/expkit/expkit/fuzzer/scripts/jazzer_postprocessing.py +++ b/crs/expkit/expkit/fuzzer/scripts/jazzer_postprocessing.py @@ -64,6 +64,12 @@ CRASH_ARTIFACT_LINE_PTRN = re.compile( r"^(\d+)\sartifact_prefix=.*; Test unit written to .*/artifacts/((crash|timeout)-[a-z0-9]+)" ) +OLD_CRASH_REPRO_ARTIFACT_LINE_PTRN = re.compile( + r"^Executed (\S*) in \d+ ms" +) +CRASH_REPRO_ARTIFACT_LINE_PTRN = re.compile( + r"^(\d+)\sExecuted (\S*) in \d+ ms" +) # NOTE: TODO: libafl-jazzer currently does not log timeout artifact info LIBAFL_JAZZER_CRASH_ARTIFACT_LINE_PTRN = re.compile( r"^(\d+)\s\[libafl\] Received jazzer death callback! Dumping corpus as crash to .*/artifacts/(crash-[a-z0-9]+)" @@ -349,28 +355,37 @@ def _parse_artifact_line( ) -> Optional[Tuple]: global is_libafl_jazzer + # Define patterns based on mode if is_libafl_jazzer: - match = LIBAFL_JAZZER_CRASH_ARTIFACT_LINE_PTRN.match(line) - if match: - timestamp, artifact = int(match.group(1)), match.group(2) - elapsed_time = ( - timestamp - initial_timestamp if initial_timestamp is not None else None - ) - return elapsed_time, artifact - return None + patterns_with_timestamp = [LIBAFL_JAZZER_CRASH_ARTIFACT_LINE_PTRN] + patterns_without_timestamp = [] else: - match = CRASH_ARTIFACT_LINE_PTRN.match(line) + patterns_with_timestamp = [ + CRASH_ARTIFACT_LINE_PTRN, + CRASH_REPRO_ARTIFACT_LINE_PTRN + ] + patterns_without_timestamp = [ + OLD_CRASH_ARTIFACT_LINE_PTRN, + OLD_CRASH_REPRO_ARTIFACT_LINE_PTRN + ] + + # Check patterns with timestamp (groups 1 and 2) + for pattern in patterns_with_timestamp: + match = pattern.match(line) if match: timestamp, artifact = int(match.group(1)), match.group(2) elapsed_time = ( timestamp - initial_timestamp if initial_timestamp is not None else None ) return elapsed_time, artifact - else: - match = OLD_CRASH_ARTIFACT_LINE_PTRN.match(line) - if match: - artifact = match.group(1) - return None, artifact + + # Check patterns without timestamp (only group 1) + for pattern in patterns_without_timestamp: + match = pattern.match(line) + if match: + artifact = match.group(1) + return None, artifact + return None diff --git a/crs/expkit/expkit/fuzzer/scripts/run-jazzer.sh b/crs/expkit/expkit/fuzzer/scripts/run-jazzer.sh index e72784f49..c4eeb69f9 100755 --- a/crs/expkit/expkit/fuzzer/scripts/run-jazzer.sh +++ b/crs/expkit/expkit/fuzzer/scripts/run-jazzer.sh @@ -70,6 +70,12 @@ do unset SKIP_SEED_CORPUS fi + if [[ -n \${FUZZ_VERIFY_ONLY} ]]; then + echo "FUZZ_VERIFY_ONLY is set, skip fuzzing" + sleep 1s + break + fi + stdbuf -e 0 -o 0 \ run_fuzzer ${FUZZ_TARGET_HARNESS} \ "\$@" || echo @@@@@ exit code of Jazzer is $? 
@@@@@ >&2 diff --git a/crs/expkit/expkit/llm.py b/crs/expkit/expkit/llm.py deleted file mode 100755 index 16c61ba93..000000000 --- a/crs/expkit/expkit/llm.py +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import os -import time - -import litellm -import openai -from litellm import completion, completion_cost - -from .utils import CRS_ERR_LOG, CRS_WARN_LOG - -CRS_ERR = CRS_ERR_LOG("llm") -CRS_WARN = CRS_WARN_LOG("llm") - - -logger = logging.getLogger(__name__) - - -class LLMClient: - """ - A client for multi-provider LLM API calls using litellm with unified interface, - error handling, and usage tracking. - """ - - def __init__( - self, timeout: int = 240, max_retries: int = 10, verbose: bool = False - ): - self.api_key = os.environ.get("LITELLM_KEY") or os.environ.get("OSS_CRS_LLM_API_KEY") - self.base_url = os.environ.get("AIXCC_LITELLM_HOSTNAME") or os.environ.get("OSS_CRS_LLM_API_URL") - - if not self.api_key: - logger.error(f"{CRS_WARN} LITELLM_KEY environment variable not set") - raise ValueError("LITELLM_KEY environment variable must be set") - - if not self.base_url: - logger.error( - f"{CRS_ERR} AIXCC_LITELLM_HOSTNAME environment variable not set" - ) - raise ValueError("AIXCC_LITELLM_HOSTNAME environment variable must be set") - - # Configure litellm defaults - litellm.drop_params = True - litellm.request_timeout = timeout - litellm.num_retries = max_retries - litellm.set_verbose = verbose - - # Usage tracking - self.total_cost = 0.0 - self.total_tokens = 0 - self.total_prompt_tokens = 0 - self.total_completion_tokens = 0 - self.request_count = 0 - - def completion( - self, - prompt: str, - model: str, - system_prompt: str | None = None, - temperature: float | None = None, - tools: list[dict] | None = None, - tool_choice: str | dict | None = None, - ) -> dict: - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": prompt}) - - params = { - "model": model, - "messages": messages, - "api_key": self.api_key, - "base_url": self.base_url, - "request_timeout": litellm.request_timeout, - } - - params["temperature"] = temperature if temperature is not None else 1.0 - if tools: - params["tools"] = tools - if tool_choice: - params["tool_choice"] = tool_choice - - if "claude" in model: - params["max_tokens"] = 32000 - params["thinking"] = {"type": "enabled", "budget_tokens": 20000} - params["temperature"] = 1 - - try: - start_time = time.time() - - if "claude" in model: - response = completion(**params) - else: - cli = openai.OpenAI( - api_key=self.api_key, - base_url=self.base_url, - ) - response = cli.chat.completions.create(model=model, messages=messages) - - return self._process_response(response, params["model"], start_time) - - except Exception as e: - logger.error( - f"{CRS_WARN} Error in completion request: {str(e)}", exc_info=True - ) - raise - - def _process_response(self, response, model, start_time): - first_choice = response.choices[0] - if hasattr(first_choice, "message") and hasattr( - first_choice.message, "content" - ): - content = first_choice.message.content - else: - content = str(first_choice) - - cost = completion_cost(completion_response=response) - self._update_metrics(response, cost) - - elapsed = time.time() - start_time - logger.info(f"Request completed in {elapsed:.2f}s, cost: ${cost:.6f}") - - return { - "content": content, - "cost": cost, - "elapsed_time": elapsed, - "model": model, - "usage": ( - response.usage.model_dump() if hasattr(response, 
"usage") else None - ), - "tool_calls": ( - first_choice.message.tool_calls - if hasattr(first_choice.message, "tool_calls") - else None - ), - "raw_response": response, - } - - def _update_metrics(self, response, cost): - """Update internal tracking metrics""" - self.total_cost += cost - self.request_count += 1 - - if hasattr(response, "usage"): - self.total_tokens += response.usage.total_tokens - self.total_prompt_tokens += response.usage.prompt_tokens - self.total_completion_tokens += response.usage.completion_tokens - - def get_usage_stats(self) -> dict: - """Get detailed usage statistics for all requests made through this client""" - return { - "total_cost": self.total_cost, - "total_tokens": self.total_tokens, - "total_prompt_tokens": self.total_prompt_tokens, - "total_completion_tokens": self.total_completion_tokens, - "request_count": self.request_count, - "average_cost_per_request": ( - self.total_cost / self.request_count if self.request_count else 0 - ), - "average_tokens_per_request": ( - self.total_tokens / self.request_count if self.request_count else 0 - ), - } - - def print_usage_stats(self) -> str: - """Format usage statistics as a human-readable string""" - stats = self.get_usage_stats() - - return ( - f"LLM Usage Statistics:\n" - f" Total Requests: {stats['request_count']}\n" - f" Total Cost: ${stats['total_cost']:.6f}\n" - f" Total Tokens: {stats['total_tokens']}\n" - f" - Prompt Tokens: {stats['total_prompt_tokens']}\n" - f" - Completion Tokens: {stats['total_completion_tokens']}\n" - f" Average Cost per Request: ${stats['average_cost_per_request']:.6f}\n" - f" Average Tokens per Request: {int(stats['average_tokens_per_request'])}\n" - ) diff --git a/crs/expkit/expkit/sinkpoint_beep/exp.py b/crs/expkit/expkit/sinkpoint_beep/exp.py index bab363d98..c1849fd96 100755 --- a/crs/expkit/expkit/sinkpoint_beep/exp.py +++ b/crs/expkit/expkit/sinkpoint_beep/exp.py @@ -1,22 +1,19 @@ #!/usr/bin/env python3 +from datetime import datetime import json import logging import os -import random import traceback import uuid from pathlib import Path -import litellm -import openai - from ..beepobjs import BeepSeed from ..cpmeta import CPMetadata from ..fuzzer.jazzer import JazzerFuzzer -from ..llm import LLMClient from ..redis import RedisCacheClient from ..utils import CRS_ERR_LOG, CRS_WARN_LOG, get_usable_cpu_id +from ..agent import ExpkitExecutionContext from .prompt import PromptGenerator logger = logging.getLogger(__name__) @@ -31,39 +28,28 @@ class SinkpointExpTool: def __init__( self, - llm_client: LLMClient, redis_client: RedisCacheClient, beepseed: BeepSeed, exp_time: int, cp_meta: CPMetadata, workdir: Path = None, - gen_models: list = None, - x_models: list = None, + gen_model: str = None, + skip_redis_cache: bool = False, + output_formats: list[str] = None, + max_iterations: int = 30, + trim_context: bool = True, ): self.beepseed = beepseed - self.llm_client = llm_client self.redis_client = redis_client - self.gen_models = gen_models - self.x_models = x_models - self.gen_model = self._pick_gen_model() - self.x_model = self._pick_x_model() + self.gen_model = gen_model self.exp_time = exp_time self.cp_meta = cp_meta self.workdir = workdir self.cpu_id = get_usable_cpu_id() - self.prompt_generator = PromptGenerator(cp_meta, beepseed) - - def _pick_gen_model(self) -> str: - """Select a model for generating POCs based on weighted probabilities.""" - selected = random.choice(self.gen_models) - logger.info(f"Selected generation model: {selected}") - return selected - - def 
_pick_x_model(self) -> str: - """Select a model for extracting hex strings based on weighted probabilities.""" - selected = random.choice(self.x_models) - logger.info(f"Selected extraction model: {selected}") - return selected + self.prompt_generator = PromptGenerator(cp_meta, beepseed, trim_context=trim_context) + self.skip_redis_cache = skip_redis_cache + self.output_formats = output_formats if output_formats is not None else ["hexstring", "bytes", "script"] + self.max_iterations = max_iterations def check_exp_status(self, fuzz_log: Path) -> bool: if not fuzz_log.exists(): @@ -164,20 +150,44 @@ def _add_beepseed_to_corpus(self, jazzer): logger.error(f"{CRS_ERR} Failed to add beepseed to corpus: {e}") return None - def _add_poc_to_corpus(self, jazzer): + def _serialize_steps(self, result: dict) -> str: + """Serialize intermediate steps to a string.""" + + result_str = [] + + result_str.append("=" * 80 + "\n") + result_str.append("AGENT THOUGHT PROCESS LOG\n") + result_str.append("=" * 80 + "\n\n") + result_str.append(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + result_str.append(f"INPUT:\n{result.get('input', '')}\n\n") + result_str.append("-" * 80 + "\n\n") + + # Write each step + for i, step in enumerate(result.get("intermediate_steps", []), 1): + agent_action, observation = step + + result_str.append(f"STEP {i}:\n\n") + result_str.append(f"Thought/Reasoning:\n{agent_action.log}\n\n") + result_str.append(f"Action:\n Tool: {agent_action.tool}\n") + result_str.append(f" Input: {agent_action.tool_input}\n\n") + result_str.append(f"Observation/Result:\n{observation}\n\n") + result_str.append("-" * 80 + "\n\n") + + result_str.append(f"FINAL OUTPUT:\n{result.get('output', '')}\n\n") + result_str.append("=" * 80 + "\n") + + return "".join(result_str) + + + def _add_poc_to_corpus(self, jazzer, execution_context): """Generate POC content using LLM and add to corpus.""" try: - # Check if either model is 'none' to skip LLM queries - if self.gen_model.lower() == "none" or self.x_model.lower() == "none": - logger.info( - f"Skipping POC generation as model selection includes 'none' (gen_model={self.gen_model})" - ) - return None - logger.info("Generating POC using LLM") + steps_until_solved = None # check cache first x_hexstr = self.redis_client.get(self.beepseed, self.gen_model, "x_hexstr") - if x_hexstr is not None: + if x_hexstr is not None and not self.skip_redis_cache: logger.info( f"CACHE: Found cached x_hexstr in Redis: {x_hexstr} for {self.beepseed.redis_key()}" ) @@ -186,67 +196,50 @@ def _add_poc_to_corpus(self, jazzer): f"CACHE: No cached x_hexstr found, generating new one for {self.beepseed.redis_key()}" ) try: - poc_content = self.redis_client.get( - self.beepseed, self.gen_model, "poc_content" - ) - if poc_content is None: - logger.info( - f"CACHE: No cached poc_content found, generating new one for {self.beepseed.redis_key()}" - ) + poc_prompt = self.prompt_generator.generate_poc_prompt() + if self.workdir: + prompt_file = self.workdir / "poc_prompt.txt" + with open(prompt_file, "w") as f: + f.write(poc_prompt) + logger.info(f"Initial prompt saved to {prompt_file} ({len(poc_prompt)} chars)") + poc_response = execution_context.run(poc_prompt) + # Save response to a file in the workdir for debugging + if self.workdir: + if 'output' in poc_response: + resp_file = self.workdir / f"poc_response.txt" + with open(resp_file, "w") as f: + f.write(poc_response['output']) + logger.info(f"LLM response saved to {resp_file}") + if 'intermediate_steps' in poc_response: + 
steps_file = self.workdir / f"intermediate_steps.txt" + with open(steps_file, "w") as f: + f.write(self._serialize_steps(poc_response)) + logger.info(f"LLM response saved to {steps_file}") + + # Also parse the tool result to determine in which step the solution was found + for i, step in enumerate(poc_response.get("intermediate_steps", []), 1): + _, observation = step + if "Exploit successful" in observation: + steps_until_solved = i + logger.info(f"Solution found at step {steps_until_solved}") + break + x_hexstr = execution_context.solution() + logger.info(f"LLM solution: {x_hexstr}") - poc_prompt = self.prompt_generator.generate_poc_prompt() - poc_response = self.llm_client.completion( - prompt=poc_prompt, - model=self.gen_model, - temperature=1.0, - ) - poc_content = poc_response["content"] - - logger.info( - f"CACHE: Caching poc_content for {self.beepseed.redis_key()}" - ) - self.redis_client.set( - self.beepseed, self.gen_model, "poc_content", poc_content - ) - - x_hexstr_prompt = self.prompt_generator.generate_x_hexstr_prompt( - poc_content - ) - x_response = self.llm_client.completion( - prompt=x_hexstr_prompt, - model=self.x_model, - temperature=0.1, - ) - x_hexstr = x_response["content"] if x_hexstr is None: logger.warning( "CACHE: LLM returned None for x_hexstr, set it as empty string" ) x_hexstr = "" - except ( - litellm.InternalServerError, - litellm.Timeout, - litellm.ServiceUnavailableError, - litellm.RateLimitError, - openai.APITimeoutError, - openai.RateLimitError, - openai.InternalServerError, - openai.APIStatusError, - ) as e: - logger.info( - f"CACHE: Meet LLM error: {e}, do not cache POC hex string, will retry later: {traceback.format_exc()}" - ) - return None - except Exception as e: - logger.error( - f"CACHE: Meet unexpected error while generating POC: {e}, will not retry: {traceback.format_exc()}" + logger.info( + f"Failed to generate a poc: {e}, unable to cache POC hex string, will retry later: {traceback.format_exc()}" ) - x_hexstr = "" + return None, None logger.info( - f"CACHE: Set POC hex string in Redis cache for {self.beepseed.redis_key()}" + f"CACHE: Setting POC hex string in Redis cache for {self.beepseed.redis_key()}" ) self.redis_client.set( self.beepseed, self.gen_model, "x_hexstr", x_hexstr @@ -258,16 +251,16 @@ def _add_poc_to_corpus(self, jazzer): if poc_bytes: poc_file = jazzer.add_corpus_file(poc_bytes, "poc") logger.info(f"Added POC to corpus ({len(poc_bytes)} bytes)") - return poc_file + return poc_file, steps_until_solved except Exception as e: logger.warning(f"Invalid LLM-generated hex string in POC content: {e}") - return None + return None, None except Exception as e: logger.error( f"Failed to generate and add POC to corpus: {e} {traceback.format_exc()}" ) - return None + return None, None def exploit(self) -> dict: """Perform sinkpoint beepseed exploitation.""" @@ -292,7 +285,7 @@ def exploit(self) -> dict: ) logger.info( - f"Initializing fuzzer for {target_harness} with classpath {target_classpath}" + f"Initializing fuzzer for {target_harness} with classpath {target_classpath} (workdir={work_dir})" ) jazzer = JazzerFuzzer( @@ -309,27 +302,50 @@ def exploit(self) -> dict: ], ) + execution_context = ExpkitExecutionContext( + self.gen_model, 1.0, + jazzer, self.beepseed, + work_dir, + self.output_formats, + self.cp_meta, + self.max_iterations + ) + # self._dump_deepgen_task() self._add_beepseed_to_corpus(jazzer) - self._add_poc_to_corpus(jazzer) + poc_file, steps_until_solved = self._add_poc_to_corpus(jazzer, execution_context) + + exp_succ = 
False + if not poc_file: + logger.error(f"{CRS_ERR} No valid POC added to corpus, skipping jazzer") + result = { + "status": False, + "error": "No valid POC added to corpus, skipping exploitation", + "workdir": str(work_dir), + "fuzz_id": fuzz_id, + } - logger.info(f"Running fuzzer for {self.exp_time}s with ID {fuzz_id}") + else: + logger.info(f"Running fuzzer for {self.exp_time}s with ID {fuzz_id}") - result_json = jazzer.fuzz( - fuzz_id=fuzz_id, fuzz_time=self.exp_time, mem_size=4096 - ) - exp_succ = self.check_exp_status(jazzer.fuzz_log) - result = { - "status": exp_succ, - "cp_name": self.cp_meta.get_cp_name(), - "coordinate": self.beepseed.coord.to_dict(), - "workdir": str(work_dir), - "fuzz_id": fuzz_id, - "results_json": str(result_json) if result_json else None, - } + result_json = jazzer.fuzz( + fuzz_id=fuzz_id, fuzz_time=self.exp_time, mem_size=4096 + ) + exp_succ = self.check_exp_status(jazzer.fuzz_log) + tool_call_solved = execution_context.tool_call_solved() + result = { + "status": exp_succ, + "tool_call_solved": tool_call_solved, + "steps_until_solved": steps_until_solved, + "cp_name": self.cp_meta.get_cp_name(), + "coordinate": self.beepseed.coord.to_dict(), + "workdir": str(work_dir), + "fuzz_id": fuzz_id, + "results_json": str(result_json) if result_json else None, + } logger.info(f"Exploitation completed with status: {exp_succ}") - logger.info(self.llm_client.print_usage_stats()) + logger.info(f"Cost summary: {execution_context.get_cost_summary_verbose()}") return result except Exception as e: diff --git a/crs/expkit/expkit/sinkpoint_beep/prompt.py b/crs/expkit/expkit/sinkpoint_beep/prompt.py index 8c1f84403..a0eeb6593 100755 --- a/crs/expkit/expkit/sinkpoint_beep/prompt.py +++ b/crs/expkit/expkit/sinkpoint_beep/prompt.py @@ -86,6 +86,12 @@ class PromptGenerator: """, "sink-batik-TranscoderInput": """ The code is potentially vulnerable if the first argument of TranscoderInput can be controlled. Typically, given one type of transcoder input, such as a svg file,the attacker can embed payload by crafting the external resource contained inside the svg, such as xlink:href or href in svg file. If the target program doesn't securely handle the embedded metadata, it can lead to various consequences depending on the input and the payload, such as File Path Traversal, Server Side Request Forgery (SSRF), Remote Code Execution (RCE), or Denial of Service (DoS) conditions. In our context, the PoC needs to be able to trigger the Jazzer sanitizer hooked functions, such as initializing a class 'jaz.Zer', invoking a system command called 'jazze', or accessing any external URL (such as websites or host/IP:port combinations), accessing a file named 'jazzer-traversal', or causing a timeout or OOM condition for DoS, etc. +""", + "sink-ScriptEngineInjection": """ +The code is potentially vulnerable to a script injection. Inputs that contain the string '"jaz"+"zer"' (including the double quotes) during execution will be detected by Jazzer as valid proof of concept. +""", + "sink-RemoteCodeExecution": """ +The code is potentially vulnerable to Remote Code Execution (RCE). This can happen when the code uses unsafe deserialization, unsafe reflection, or other mechanisms that allow for dynamic code execution. Inputs that cause the execution to trigger any Jazzer sanitizer hooked functions will be detected as valid proof of concept. This includes scenarios such as initializing a class 'jaz.Zer' and invoking a system command called 'jazze'. 
""", } @@ -101,10 +107,11 @@ def _load_template(cls, filename): with open(template_path) as f: return f.read() - def __init__(self, cp_meta: CPMetadata, beepseed: BeepSeed): + def __init__(self, cp_meta: CPMetadata, beepseed: BeepSeed, trim_context: bool = True): self.cp_meta = cp_meta self.cp_name = cp_meta.get_cp_name() self.beepseed = beepseed + self.trim_context = trim_context if PromptGenerator._poc_template is None: PromptGenerator._poc_template = self._load_template("gen-poc.txt") @@ -117,31 +124,98 @@ def __init__(self, cp_meta: CPMetadata, beepseed: BeepSeed): if PromptGenerator._script_template is None: PromptGenerator._script_template = self._load_template("gen-script.txt") + def _get_file_content_with_context(self, file_path: str, line_numbers: list[int]) -> str: + """ + Get file content - either full file or snippets around line numbers for large files. + + Args: + file_path: Path to the source file + line_numbers: Line numbers from stack trace frames + + Returns: + Formatted file content with line numbers + """ + # If trim_context is False, always return the full file + if not self.trim_context: + return cat_n(file_path) + + # Define line count threshold for when to use snippets + LINE_COUNT_THRESHOLD = 150 + + # Check line count + try: + with open(file_path) as f: + line_count = sum(1 for _ in f) + except (FileNotFoundError, OSError): + line_count = 0 + + if line_count < LINE_COUNT_THRESHOLD: + # Small file: show entire content + return cat_n(file_path) + + # Large file: show merged snippets around line numbers + if not line_numbers: + return cat_n(file_path) + + line_nums = sorted(line_numbers) + + # Merge overlapping or nearby ranges (within 10 lines) + CONTEXT_LINES = 50 + MERGE_DISTANCE = CONTEXT_LINES * 2 + + merged_ranges = [] + for line_num in line_nums: + start = max(1, line_num - CONTEXT_LINES) + end = line_num + CONTEXT_LINES + + # Check if this range overlaps or is close to the last merged range + if merged_ranges and start <= merged_ranges[-1][1] + MERGE_DISTANCE: + # Merge with previous range + merged_ranges[-1] = ( + merged_ranges[-1][0], + max(merged_ranges[-1][1], end), + ) + else: + # Add new range + merged_ranges.append((start, end)) + + # Read file content for the merged ranges + snippets = [] + for start, end in merged_ranges: + snippet = cat_n_at_line( + file_path, (start + end) // 2, context_lines=(end - start) // 2 + ) + snippets.append(snippet) + + if snippets: + return "\n...\n".join(snippets) + else: + return cat_n(file_path) + def get_code_files(self) -> str: if not self.beepseed.stack_trace: return "No stack trace available to extract code files" - # Find the stack frames that has CP project source files - file_paths = set() + # Find the stack frames that have CP project source files and collect line numbers + file_line_ranges = {} # file_path -> list of line numbers for frame in self.beepseed.stack_trace: source_path = self.cp_meta.resolve_frame_to_file_path(frame) if source_path: - file_paths.add(source_path) + line_num = frame.get("line_num", -1) + if line_num > 0: + if source_path not in file_line_ranges: + file_line_ranges[source_path] = [] + file_line_ranges[source_path].append(line_num) - if not file_paths: + if not file_line_ranges: return "No relevant source code files found in stack trace" formatted_files = [] - for file_path in sorted(file_paths): - """ - For each file: - ``` - // File: /path/to/file - cat -n style content - ``` - """ + for file_path in sorted(file_line_ranges.keys()): file_header = f"// File: {file_path}" - 
file_content = cat_n(file_path)
+            file_content = self._get_file_content_with_context(
+                file_path, file_line_ranges[file_path]
+            )
             formatted_file = f"```\n{file_header}\n{file_content}\n```"
             formatted_files.append(formatted_file)
diff --git a/crs/expkit/expkit/sinkpoint_beep/prompts/gen-poc.txt b/crs/expkit/expkit/sinkpoint_beep/prompts/gen-poc.txt
index 4a335d1bb..deec66333 100644
--- a/crs/expkit/expkit/sinkpoint_beep/prompts/gen-poc.txt
+++ b/crs/expkit/expkit/sinkpoint_beep/prompts/gen-poc.txt
@@ -6,36 +6,44 @@ Relevant code files:
 PROMPT_CODE_FILES
-Potential vulnerable target (refer as sinkpoint in following) basics:
+Potentially vulnerable target (referred to as the sinkpoint hereafter) basics:
 ```java
 // PROMPT_SINKPOINT_FILE_PATH, PROMPT_SINKPOINT_LINENO
 ...
 PROMPT_SINKPOINT_LINE_CONTENT
-// PROMPT_SINKPOINT_VULN_DESC
 ...
 ```
+PROMPT_SINKPOINT_VULN_DESC
-Input whose execution reaches the sinkpoint but not poc:
+The following input reaches the sinkpoint, but does not trigger the sanitizer:
 ```bash
 # hexdump
 PROMPT_BEEPSEED_HEXDUMP
+```
 
-# stacktrace at reaching point
+```
+# stacktrace of the above input when reaching the sinkpoint
 PROMPT_BEEPSEED_STACKTRACE
 ```
 Task:
-We are conducting AI-driven security testing, our Jazzer fuzzer has produced an input that leads the execution to reach the sinkpoint shown above.
+We are conducting AI-driven security testing; our fuzzer Jazzer has produced an input that leads the execution to reach the sinkpoint shown above.
-Follow the steps below and generate the poc:
+Follow the steps below to generate the POC:
-- analyze why and how the sinkpoint is potentially vulnerable to that kind of vulnerability
-- infer what kind of input can leads the execution to an exploitation stage
-- considering the conditions/constraints along the execution path, construct the actual poc
+- Analyze why and how the sinkpoint is potentially vulnerable to that kind of vulnerability.
+- Infer what kind of input can lead the execution to exploit the vulnerability.
+- Considering the conditions/constraints along the execution path, construct the actual POC.
 Note that:
-- the focus of the poc is to trigger the targeted Jazzer sanitizer
-- in your response, the generated poc content should be wrapped by ``` and in hex string format, e.g., ```000102...```, so I can extract it and directly pass to bytes.fromhex
+- The focus of the POC is to trigger the targeted Jazzer sanitizer.
+- Provide your solution with the appropriate tool call.
+
+FuzzedDataProvider byte layout (if the harness uses FDP):
+- Front cursor advances FORWARD for `consumeString/Bytes` and their `Remaining` variants.
+- Back cursor advances BACKWARD for all integer/boolean/char/float/enum/pickValue methods.
+- Multiple back-side calls read the tail IN REVERSE CALL ORDER: the first `consumeInt()` takes the LAST 4 bytes, the next takes the 4 before that, etc.
+- Ranged methods (`consumeInt(min,max)`, `consumeEnum`, `pickValue`) consume only enough bytes for the cardinality, not the natural width.
diff --git a/crs/expkit/expkit/sinkpoint_beep/prompts/gen-script.txt b/crs/expkit/expkit/sinkpoint_beep/prompts/gen-script.txt
index 8a2fabbf6..243fe5e37 100644
--- a/crs/expkit/expkit/sinkpoint_beep/prompts/gen-script.txt
+++ b/crs/expkit/expkit/sinkpoint_beep/prompts/gen-script.txt
@@ -41,6 +41,12 @@ Format of the generated script:
 - To allow our tool correctly utilize the script, it must implement a function named `gen_one_seed` that returns one PoC for that fuzzing harness in bytes. An example is shown below. 
- This is just a suspicious sinkpoint, not a confirmed vulnerability. If you confirmed it is not exploitable after reasoning, you should just return the below unimplemented example, with a comment "NOT EXPLOITABLE" in `gen_one_seed` function. +FuzzedDataProvider byte layout (if the harness uses FDP): +- Front cursor advances FORWARD for `consumeString/Bytes` and their `Remaining` variants. +- Back cursor advances BACKWARD for all integer/boolean/char/float/enum/pickValue methods. +- Multiple back-side calls read the tail IN REVERSE CALL ORDER: the first `consumeInt()` takes the LAST 4 bytes, the next takes the 4 before that, etc. +- Ranged methods (`consumeInt(min,max)`, `consumeEnum`, `pickValue`) consume only enough bytes for the cardinality, not the natural width. + def gen_one_seed() -> bytes: # for real usage # TODO: implement this to do poc generation diff --git a/crs/expkit/expkit/utils.py b/crs/expkit/expkit/utils.py index 1faa9c023..a4f005c2e 100755 --- a/crs/expkit/expkit/utils.py +++ b/crs/expkit/expkit/utils.py @@ -72,3 +72,22 @@ def get_usable_cpu_id(): logger.info("Defaulting to CPU 0") return 0 + + +def get_with_model_provider(model_name: str) -> str: + """Infer the model provider from the model name.""" + model_name = model_name.lower() + if model_name.startswith("vertex"): + return "vertex_ai/openai/" + model_name.split("/")[-1] + "-maas" + elif model_name.startswith("gpt") or model_name.startswith("o1") or model_name.startswith("o3-"): + return "openai/" + model_name + elif model_name.startswith("claude-"): + return "anthropic/" + model_name + elif model_name.startswith("gemini-"): + return "gemini/" + model_name + elif model_name.startswith("grok-"): + return "xai/" + model_name + elif model_name.startswith("zai-"): + return "vertex_ai/" + model_name + "-maas" + else: + return "unknown" diff --git a/crs/expkit/requirements.txt b/crs/expkit/requirements.txt index 1e9339760..3a0ddd00e 100644 --- a/crs/expkit/requirements.txt +++ b/crs/expkit/requirements.txt @@ -1,3 +1,7 @@ -tenacity==9.0.0 -litellm==1.65.0 +tenacity==8.5.0 hexdump2==1.2.2 +langchain==0.3.27 +langchain-community==0.3.31 +langchain-core==0.3.79 +langchain-litellm @ git+https://github.com/Akshay-Dongare/langchain-litellm.git@3d6d6752005d11ab2e946990b34fed14ce26cb2d +pydantic==2.8.0 diff --git a/crs/filtering-agent/.gitignore b/crs/filtering-agent/.gitignore new file mode 100644 index 000000000..91d464b91 --- /dev/null +++ b/crs/filtering-agent/.gitignore @@ -0,0 +1,34 @@ +# Python bytecode +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +dist/ +build/ +*.egg-info/ + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Virtual environments +venv/ +env/ +ENV/ + +# IDE specific files +.idea/ +.vscode/ +*.swp +*.swo + +# Logs +*.log diff --git a/crs/filtering-agent/README.md b/crs/filtering-agent/README.md new file mode 100644 index 000000000..c89093688 --- /dev/null +++ b/crs/filtering-agent/README.md @@ -0,0 +1,174 @@ +# Sink Picker + +An AI-powered tool that analyzes potential vulnerability sink points in Java code and selects the most promising ones for further analysis. + +## Overview + +The Sink Picker uses an LLM-powered agent to evaluate vulnerability sink locations identified by static analysis tools. It analyzes the code context, data flow, and exploitability factors to select the most promising sink point in each file. 
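An aside on the FuzzedDataProvider byte-layout notes added to gen-poc.txt and gen-script.txt above: the cursor rules are easiest to see on concrete bytes. Below is a minimal Python sketch of the layout those notes describe; it models only which byte ranges each call consumes, not Jazzer's actual value decoding, so treat it as an aid for reading the notes rather than a reimplementation of FDP.

```python
# Simplified model of the FDP cursors (byte positions only, no decoding).
data = bytes(range(12))            # bytes 0x00 .. 0x0b

# Back cursor, reverse call order: the first consumeInt() takes the LAST
# 4 bytes, the second consumeInt() takes the 4 bytes before those.
first_int_bytes = data[-4:]        # 0x08 0x09 0x0a 0x0b
second_int_bytes = data[-8:-4]     # 0x04 0x05 0x06 0x07

# Front cursor moves forward: consumeBytes(3) takes the FIRST 3 bytes,
# and a later "Remaining" call gets whatever neither cursor has eaten.
front_bytes = data[:3]             # 0x00 0x01 0x02
remaining = data[3:-8]             # just 0x03 here
```

Practically, when hand-crafting a PoC for an FDP harness, the tail of the input is laid out for the integer-like calls in reverse call order, while string and byte payloads go at the front.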
+ +## Installation + +```bash +pip3 install -r requirements.txt +pip3 install -e . +``` + +## Environment Setup + +Set the required environment variables: + +```bash +export LITELLM_KEY=your_api_key_here +export AIXCC_LITELLM_HOSTNAME=your_litellm_endpoint +``` + +## Usage + +```bash +python -m sinkpicker.pick_sinks \ + /path/to/input_sinks.json \ + /path/to/output_sinks.json \ + --metadata /path/to/cpmetadata.json \ + --harness-cwe-pairs "harness1:CWE-022,harness2:CWE-089,harness3:CWE-022" \ + --call-graph /path/to/call-graph.json \ + --workdir /path/to/working/directory \ + --verbose +``` + +### Arguments + +- `input`: Path to input JSON file with sink candidates (required) +- `output`: Path to save output JSON file with filtering decisions (required) +- `--metadata`: Path to CP metadata JSON file (required) +- `--harness-cwe-pairs`: Comma-separated list of harness:CWE pairs, e.g., "harness1:CWE-022,harness2:CWE-089". Each pair is analyzed independently. (required) +- `--call-graph`: Path to call graph JSON file in Joern format (required) +- `--workdir`: Working directory for agent artifacts (optional) +- `--gen-model`: LLM model used by the filtering agent (default: gpt-5) +- `--max-iterations`: Maximum agent iterations (default: 15) +- `--temperature`: LLM temperature (default: 0.0) +- `--max-workers`: Maximum parallel workers for processing files (default: 10) +- `--verbose`: Enable verbose logging + +### Advanced Options + +```bash +# Pick a different model +--gen-model claude-3-7-sonnet-20250219 + +# Adjust agent behavior +--max-iterations 20 +--temperature 0.1 + +# Control parallel processing (useful for managing API rate limits) +--max-workers 5 +``` + +## Input Format + +JSON array of sink candidates with harness information: +```json +[ + { + "coord": { + "line_num": 840, + "file_name": "src/main/java/Example.java", + "start_column": 30, + "end_column": 43 + }, + "id": "src/main/java/Example.java:840:30:840:43", + "harness": "harness1", + "cwe": "CWE-022", + "message": "Path traversal sink", + "filtered_out_flow": false, + "filtered_out_test": false, + "reachable": true + } +] +``` + +## Output Format + +Same format as input, with additional `unexploitable` and `in_final_result` fields: +- `unexploitable`: List of harness names where the sink is NOT exploitable (well-protected) +- `in_final_result`: List of harness names where the sink should be included in the final result set + +```json +[ + { + "coord": { + "line_num": 840, + "file_name": "src/main/java/Example.java", + "start_column": 30, + "end_column": 43 + }, + "id": "src/main/java/Example.java:840:30:840:43", + "harness": "harness1", + "cwe": "CWE-022", + "message": "Path traversal sink", + "filtered_out_flow": false, + "filtered_out_test": false, + "reachable": true, + "unexploitable": ["harness2"], + "in_final_result": ["harness1", "harness3"] + } +] +``` + +**Notes**: +- `unexploitable`: List of harnesses for which the agent determined the sink is unexploitable. Only added to sinks that were analyzed by the agent. +- `in_final_result`: List of harnesses for which the sink should be in the final result + - Includes harnesses where no agent analysis was performed + - Includes harnesses where the sink is potentially exploitable + - Includes harnesses where analysis failed (conservative approach) + - Excludes harnesses where agent determined sink is unexploitable +- Since analysis is per-harness, the same sink may be unexploitable for one harness but exploitable for another + +## How It Works + +1. 
**Input Processing**: Loads sink candidates and filters by harness and CWE for each pair
+2. **Harness File Loading**: Searches for and reads the harness file (e.g., `JenkinsOne.java`) from the oss-fuzz directory
+3. **Call Path Discovery**: Uses call graph to find path from `fuzzerTestOneInput` to sink location
+4. **Parallel Analysis**: Processes multiple sinks concurrently (up to 10 by default) using a **two-stage LangGraph workflow**:
+
+   **Stage 1 - Research Phase:**
+   - Agent receives initial prompt with:
+     - Harness file contents (e.g., `JenkinsOne.java`)
+     - Call path from `fuzzerTestOneInput` to the sink
+   - Agent explores code using tools (read_file, list_directory, search_files)
+   - Gathers information about data flow, input validation, and accessibility
+   - Call path helps trace how user input flows through the execution to reach the sink
+   - Runs for up to max_iterations (default: 15)
+
+   **Stage 2 - Selection Phase:**
+   - Takes research findings and initial prompt as input
+   - Uses structured LLM output to make final decision
+   - Returns SinkSelection with chosen sink ID and reasoning
+   - Guarantees a decision (either a sink ID or 'none')
+
+5. **Output Generation**: Updates all sinks with the `unexploitable` and `in_final_result` fields described above
+
+## Supported CWEs
+
+- CWE-022: Path Traversal
+- CWE-078: OS Command Injection
+- CWE-089: SQL Injection
+- CWE-090: LDAP Injection
+- CWE-094: Code Injection (Bean Validation & Script Injection)
+- CWE-117: Log Injection
+- CWE-470: Unsafe Reflection
+- CWE-502: Deserialization of Untrusted Data
+- CWE-611: XXE (XML External Entity)
+- CWE-643: XPath Injection
+- CWE-730: Regex Injection / ReDoS
+- CWE-918: Server-Side Request Forgery (SSRF)
+
+## Architecture
+
+The tool consists of three main components:
+
+1. **pick_sinks.py**: Main orchestration logic with parallel processing
+2. **picker_agent.py**: Two-stage LangGraph agent
+   - Research phase: ReAct agent with exploration tools
+   - Selection phase: Structured LLM output (Pydantic models)
+   - State graph manages flow between phases
+3. 
**picker_prompt.py**: CWE-specific prompt generation with code context diff --git a/crs/filtering-agent/requirements.txt b/crs/filtering-agent/requirements.txt new file mode 100644 index 000000000..40c829930 --- /dev/null +++ b/crs/filtering-agent/requirements.txt @@ -0,0 +1,9 @@ +tenacity==8.5.0 +func-timeout==4.3.5 +langchain==0.3.27 +langchain-community==0.3.31 +langchain-core==0.3.79 +langchain-litellm @ git+https://github.com/Akshay-Dongare/langchain-litellm.git@3d6d6752005d11ab2e946990b34fed14ce26cb2d +langgraph==1.0.1 +#langgraph>=0.2.0 +pydantic==2.8.0 diff --git a/crs/filtering-agent/setup.py b/crs/filtering-agent/setup.py new file mode 100755 index 000000000..23efa99ac --- /dev/null +++ b/crs/filtering-agent/setup.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +from setuptools import find_packages, setup + +setup( + name="sink-picker", + version="0.1.0", + description="AI-powered sink point selection tool for vulnerability analysis", + author="Cen Zhang, Fabian Fleischer", + packages=find_packages(), + entry_points={ + "console_scripts": [ + "pick-sinks=sinkpicker.pick_sinks:main", + ], + }, + python_requires=">=3.8", + install_requires=[ + "argparse", + ], +) diff --git a/crs/filtering-agent/sinkpicker/__init__.py b/crs/filtering-agent/sinkpicker/__init__.py new file mode 100755 index 000000000..31e859398 --- /dev/null +++ b/crs/filtering-agent/sinkpicker/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python3 +"""Sink picker tool for vulnerability analysis.""" + +__version__ = "0.1.0" diff --git a/crs/filtering-agent/sinkpicker/assessor_agent.py b/crs/filtering-agent/sinkpicker/assessor_agent.py new file mode 100644 index 000000000..7c30f2596 --- /dev/null +++ b/crs/filtering-agent/sinkpicker/assessor_agent.py @@ -0,0 +1,747 @@ +#!/usr/bin/env python3 + +import fnmatch +import logging +import os +from pathlib import Path +from typing import Any, Dict, Type + +from func_timeout import func_timeout, FunctionTimedOut +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import LLMResult +from langchain.tools import BaseTool +from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, MessagesPlaceholder +from langchain_litellm import ChatLiteLLM +from langgraph.graph import StateGraph, END +from litellm import cost_per_token, RateLimitError +from pydantic import BaseModel, Field +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential, before_sleep_log, after_log +from typing import Optional, TypedDict, Annotated, Sequence +import operator + +from .cpmeta import CPMetadata +from .utils import CRS_ERR_LOG, CRS_WARN_LOG, get_with_model_provider +from .assessor_prompt import ExploitabilityPromptGenerator + +CRS_ERR = CRS_ERR_LOG("exploitability_agent") +CRS_WARN = CRS_WARN_LOG("exploitability_agent") +logger = logging.getLogger(__name__) +logging.getLogger("LiteLLM").setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s - [%(threadName)s] - %(levelname)s - %(message)s") +logging.getLogger("LiteLLM").handlers[0].setFormatter(formatter) + + +class FilePathInput(BaseModel): + """Input schema for file path.""" + file_path: str = Field(description="Path to the file to read (absolute or relative)", alias="path") + + model_config = {"populate_by_name": True} + + +class UnifiedReadFileTool(BaseTool): + """Unified read file tool that checks both root directories.""" + name: str = "read_file" + description: str = "Read a file from either the 
OSS-Fuzz project directory or the challenge project source code. Provide either an absolute path or a relative path (will check both roots)." + args_schema: Type[BaseModel] = FilePathInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, file_path: str) -> str: + """Read file from one of the configured roots.""" + import html + logger.info(f"[Tool] read_file START: {file_path}") + + path = Path(file_path) + result = None + + def read_file_content(p: Path) -> str: + logger.info(f"[Tool] read_file: Reading file {p}") + result = p.read_text() + + return result + + # If absolute path, check if it's within one of the roots + if path.is_absolute(): + if self.ossfuzz_root and path.is_relative_to(self.ossfuzz_root): + if path.exists() and path.is_file(): + result = read_file_content(path) + else: + result = f"Error: File not found: {file_path}" + elif self.source_root and path.is_relative_to(self.source_root): + if path.exists() and path.is_file(): + result = read_file_content(path) + else: + result = f"Error: File not found: {file_path}" + else: + result = f"Error: Path {file_path} is not within allowed directories ({self.ossfuzz_root}, {self.source_root})" + + # If relative path, try both roots + else: + if self.ossfuzz_root: + ossfuzz_path = Path(self.ossfuzz_root) / file_path + if ossfuzz_path.exists() and ossfuzz_path.is_file(): + result = read_file_content(ossfuzz_path) + + if result is None and self.source_root: + source_path = Path(self.source_root) / file_path + if source_path.exists() and source_path.is_file(): + result = read_file_content(source_path) + + if result is None: + result = f"Error: File not found in any configured directory: {file_path}" + + logger.info(f"[Tool] read_file END: {len(result)} chars") + return result + + +class DirectoryPathInput(BaseModel): + """Input schema for directory path.""" + dir_path: str = Field(description="Path to the directory to list (absolute or relative)") + + +class UnifiedListDirectoryTool(BaseTool): + """Unified list directory tool that checks both root directories.""" + name: str = "list_directory" + description: str = "List contents of a directory from either the OSS-Fuzz project directory or the challenge project source code. Provide either an absolute path or a relative path (will check both roots)." 
+ args_schema: Type[BaseModel] = DirectoryPathInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, dir_path: str) -> str: + """List directory contents from one of the configured roots.""" + logger.info(f"[Tool] list_directory START: {dir_path}") + + path = Path(dir_path) + + result = None + + # If absolute path, check if it's within one of the roots + if path.is_absolute(): + if self.ossfuzz_root and path.is_relative_to(self.ossfuzz_root): + if path.exists() and path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in path.iterdir()] + result = "\n".join(sorted(items)) + else: + result = f"Error: Directory not found: {dir_path}" + elif self.source_root and path.is_relative_to(self.source_root): + if path.exists() and path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in path.iterdir()] + result = "\n".join(sorted(items)) + else: + result = f"Error: Directory not found: {dir_path}" + else: + result = f"Error: Path {dir_path} is not within allowed directories ({self.ossfuzz_root}, {self.source_root})" + + # If relative path, try both roots + else: + if self.ossfuzz_root: + ossfuzz_path = Path(self.ossfuzz_root) / dir_path + if ossfuzz_path.exists() and ossfuzz_path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in ossfuzz_path.iterdir()] + result = "\n".join(sorted(items)) + + if result is None and self.source_root: + source_path = Path(self.source_root) / dir_path + if source_path.exists() and source_path.is_dir(): + items = [item.name + "/" if item.is_dir() else item.name for item in source_path.iterdir()] + result = "\n".join(sorted(items)) + + if result is None: + result = f"Error: Directory not found in any configured directory: {dir_path}" + + logger.info(f"[Tool] list_directory END: {len(result)} chars") + return result + + +class SearchPatternInput(BaseModel): + """Input schema for search pattern.""" + pattern: str = Field(description="File name pattern to search for (e.g., '*.java', 'Main.java')") + + +class UnifiedFileSearchTool(BaseTool): + """Unified file search tool that searches in both root directories.""" + name: str = "search_files" + description: str = "Search for files by name pattern in both the OSS-Fuzz project directory and the challenge project source code. Provide a search pattern (uses fnmatch; supports wildcards like *.java)." 
+ args_schema: Type[BaseModel] = SearchPatternInput + + ossfuzz_root: str = None + source_root: str = None + + def __init__(self, ossfuzz_root: str = None, source_root: str = None): + super().__init__() + self.ossfuzz_root = ossfuzz_root + self.source_root = source_root + + def _run(self, pattern: str) -> str: + """Search for files matching pattern in configured roots.""" + logger.info(f"[Tool] search_files START: {pattern}") + + results = [] + + def search_in_directory(root: Path, pattern: str): + matches = [] + for path in root.rglob('*'): + if path.is_file() and (fnmatch.fnmatch(str(path), pattern) or + fnmatch.fnmatch(path.name, pattern)): + matches.append(str(path)) + return matches + + if self.ossfuzz_root: + ossfuzz_results = search_in_directory(Path(self.ossfuzz_root), pattern) + results.extend(ossfuzz_results) + + if self.source_root: + source_results = search_in_directory(Path(self.source_root), pattern) + results.extend(source_results) + + if results: + result = "\n".join(sorted(results)) + else: + result = f"No files found matching pattern: {pattern}" + + logger.info(f"[Tool] search_files END: {len(result)} chars, {len(results)} files") + return result + + +class ExploitabilityAssessment(BaseModel): + """Structured output schema for exploitability assessment.""" + unexploitable: bool = Field( + description="True if the sink is NOT exploitable (well-protected), False if potentially exploitable" + ) + reasoning: str = Field( + description="Brief explanation of the exploitability assessment (2-3 sentences)" + ) + + +class AgentState(TypedDict): + """State for the two-stage exploitability assessment agent.""" + # Input + initial_prompt: str + + # Research phase + research_notes: Annotated[list[str], operator.add] + research_iterations: int + used_fallback_report: bool + + # Assessment phase + assessment: Optional[ExploitabilityAssessment] + + # Metadata + intermediate_steps: Annotated[list, operator.add] + + +class LLMInteractionLogger(BaseCallbackHandler): + """Captures every LLM prompt and response for debugging.""" + + def __init__(self): + super().__init__() + self.interactions = [] + self._pending_prompts = {} # run_id -> prompts + + def on_llm_start(self, serialized: Dict[str, Any], prompts: list[str], **kwargs: Any) -> None: + run_id = kwargs.get("run_id", "unknown") + self._pending_prompts[str(run_id)] = prompts + + def on_chat_model_start(self, serialized: Dict[str, Any], messages: list, **kwargs: Any) -> None: + run_id = str(kwargs.get("run_id", "unknown")) + formatted = [] + for msg_group in messages: + for msg in msg_group: + role = getattr(msg, "type", "unknown") + content = getattr(msg, "content", str(msg)) + tool_calls = getattr(msg, "tool_calls", None) + entry = f"[{role}] {content}" + if tool_calls: + entry += f"\n tool_calls: {tool_calls}" + formatted.append(entry) + self._pending_prompts[run_id] = formatted + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + run_id = str(kwargs.get("run_id", "unknown")) + prompts = self._pending_prompts.pop(run_id, ["(prompts not captured)"]) + + response_texts = [] + for gens in response.generations: + for gen in gens: + text = getattr(gen, "text", "") + msg = getattr(gen, "message", None) + if msg: + content = getattr(msg, "content", "") + tool_calls = getattr(msg, "tool_calls", None) + if content: + response_texts.append(str(content)) + if tool_calls: + response_texts.append(f"tool_calls: {tool_calls}") + elif text: + response_texts.append(text) + + self.interactions.append({ + "prompts": prompts, + 
"response": response_texts, + }) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + run_id = str(kwargs.get("run_id", "unknown")) + prompts = self._pending_prompts.pop(run_id, ["(prompts not captured)"]) + self.interactions.append({ + "prompts": prompts, + "response": [f"ERROR: {error}"], + }) + + def dump(self, path: Path) -> None: + """Write all interactions to a file.""" + with open(path, "w") as f: + for i, interaction in enumerate(self.interactions, 1): + f.write(f"{'='*80}\n") + f.write(f"LLM CALL #{i}\n") + f.write(f"{'='*80}\n\n") + f.write("--- PROMPT ---\n\n") + for prompt in interaction["prompts"]: + f.write(str(prompt)) + f.write("\n\n") + f.write("--- RESPONSE ---\n\n") + for resp in interaction["response"]: + f.write(str(resp)) + f.write("\n\n") + + +class LiteLLMCostTracker(BaseCallbackHandler): + """Custom callback handler to track exact costs from LiteLLM server.""" + model: str + + def __init__(self, model: str): + super().__init__() + self.model = model + self.total_cost = 0.0 + self.total_tokens = 0 + self.total_prompt_tokens = 0 + self.total_completion_tokens = 0 + self.llm_calls = [] + + def on_llm_end( + self, + response: LLMResult, + **kwargs: Any, + ) -> None: + """Calculate cost when LLM call completes.""" + + input_tokens = 0 + output_tokens = 0 + + for gens in response.generations: + for gen in gens: + if gen.message.usage_metadata: + meta = gen.message.usage_metadata + input_tokens += meta.get('input_tokens', 0) + output_tokens += meta.get('output_tokens', 0) + + try: + prompt_cost, completion_cost = cost_per_token(self.model, input_tokens, output_tokens) + except Exception as e: + logger.warning(f"Error calculating cost for model {self.model}, trying with provider: {e}") + model_with_provider = get_with_model_provider(self.model) + try: + prompt_cost, completion_cost = cost_per_token(model_with_provider, input_tokens, output_tokens) + except Exception as e2: + logger.error(f"Failed to calculate cost with provider for model {model_with_provider}: {e2}") + prompt_cost, completion_cost = 0.0, 0.0 # Fallback to zero cost if we can't calculate + + # Update totals + self.total_cost += prompt_cost + completion_cost + self.total_tokens += input_tokens + output_tokens + self.total_prompt_tokens += input_tokens + self.total_completion_tokens += output_tokens + + # Store call details + call_info = { + "model": self.model, + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cost": prompt_cost + completion_cost, + } + self.llm_calls.append(call_info) + + def get_summary(self) -> Dict[str, Any]: + """Get cost tracking summary.""" + return { + "total_cost": round(self.total_cost, 6), + "total_tokens": self.total_tokens, + "total_prompt_tokens": self.total_prompt_tokens, + "total_completion_tokens": self.total_completion_tokens, + "num_llm_calls": len(self.llm_calls), + "cost_breakdown": self.llm_calls + } + + +class ExploitabilityAssessor: + """Two-stage LangGraph agent for exploitability assessment. 
+ + Stage 1 (Research): Agent explores code with tools, gathering information + Stage 2 (Assessment): Agent uses structured output to assess exploitability + """ + llm: ChatLiteLLM + llm_with_structure: ChatLiteLLM + cost_tracker: LiteLLMCostTracker + tools: list[BaseTool] + graph: StateGraph + work_dir: Path = None + assessment_result: Optional[ExploitabilityAssessment] = None + max_research_iterations: int = 10 + intermediate_steps: list = [] + research_notes: list = [] + + def __init__( + self, + model: str, + temperature: float | None, + cp_meta: CPMetadata, + work_dir: Path, + max_iterations: int = 15, + ): + api_key = os.environ.get("LITELLM_KEY") + base_url = os.environ.get("AIXCC_LITELLM_HOSTNAME") + + self.work_dir = work_dir + self.max_research_iterations = max_iterations # Use max_iterations for research phase + + self.cost_tracker = LiteLLMCostTracker(model) + self.interaction_logger = LLMInteractionLogger() + + llm_kwargs = dict( + model=f"litellm_proxy/{model}", + api_key=api_key, + api_base=base_url, + callbacks=[self.cost_tracker, self.interaction_logger], + request_timeout=30, + ) + if temperature is not None: + llm_kwargs["temperature"] = temperature + self.llm = ChatLiteLLM(**llm_kwargs) + self.llm.streaming = False + + # Create LLM with structured output for assessment phase + self.llm_with_structure = self.llm.with_structured_output(ExploitabilityAssessment) + + # Create exploration tools + self.tools = self._create_tools(cp_meta) + + # Build the two-stage graph + self.graph = self._build_graph() + + def _create_tools(self, cp_meta: CPMetadata) -> list[BaseTool]: + """Create exploration tools for the research phase.""" + tools = [] + if cp_meta is not None: + ossfuzz_root = str(cp_meta.get_proj_path()) if cp_meta.get_proj_path() else None + source_root = str(cp_meta.get_repo_src_path()) if cp_meta.get_repo_src_path() else None + + if ossfuzz_root or source_root: + tools.append(UnifiedReadFileTool(ossfuzz_root=ossfuzz_root, source_root=source_root)) + tools.append(UnifiedListDirectoryTool(ossfuzz_root=ossfuzz_root, source_root=source_root)) + tools.append(UnifiedFileSearchTool(ossfuzz_root=ossfuzz_root, source_root=source_root)) + logger.info(f"Added exploration tools with roots: ossfuzz={ossfuzz_root}, source={source_root}") + else: + logger.warning(f"{CRS_WARN} Could not determine source roots, no exploration tools added") + return tools + + def _build_graph(self) -> StateGraph: + """Build the two-stage LangGraph: research -> assessment.""" + workflow = StateGraph(AgentState) + + # Add nodes + workflow.add_node("research", self._research_node) + workflow.add_node("assessment", self._assessment_node) + + # Set entry point + workflow.set_entry_point("research") + + # Add edges: research always goes to assessment + workflow.add_edge("research", "assessment") + workflow.add_edge("assessment", END) + + return workflow.compile() + + def _generate_fallback_report(self, initial_prompt: str, intermediate_steps: list) -> str: + """Generate a report from intermediate steps when agent doesn't produce final output. + + This is used as a fallback when the research agent reaches max_iterations without + generating a comprehensive report. 
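+
+        The generated report contains: a header noting the iteration limit, the CWE
+        description and exploitability factors (when a CWE is found in the prompt),
+        one section per tool call (agent reasoning, tool input, observation), and a
+        closing note explaining that the report was auto-generated.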
+ """ + # Extract CWE from the initial prompt to include vulnerability context + import re + cwe_match = re.search(r'\*\*(CWE-\d+):', initial_prompt) + cwe = cwe_match.group(1) if cwe_match else None + + # Build report header with vulnerability context + report_lines = [ + "# Research Report (Generated from Tool Observations)", + "", + f"The research phase reached max iterations ({self.max_research_iterations}) without generating a final report.", + "Below is a summary of the exploration activities:", + "" + ] + + # Include vulnerability description and exploitability factors if CWE is found + if cwe: + report_lines.extend([ + ExploitabilityPromptGenerator.get_cwe_description(cwe), + "", + ExploitabilityPromptGenerator.get_cwe_exploitability_factors(cwe), + "" + ]) + + if not intermediate_steps: + return "\n".join(report_lines + ["", "No research tools were used. Assessment based on initial code context only."]) + + # Continue with tool observations + report_lines.append("") + + # Build report with agent reasoning and tool observations + for i, (action, observation) in enumerate(intermediate_steps, 1): + tool_name = getattr(action, "tool", "unknown") + tool_input = getattr(action, "tool_input", {}) + reasoning = getattr(action, "log", "") + + report_lines.append(f"## Step {i}: {tool_name}") + report_lines.append("") + + # Add agent's reasoning/thought process + if reasoning: + report_lines.append("**Agent Reasoning:**") + report_lines.append(reasoning.strip()) + report_lines.append("") + + # Add tool input + if isinstance(tool_input, dict): + input_str = ", ".join([f"{k}={v}" for k, v in tool_input.items()]) + else: + input_str = str(tool_input) + report_lines.append(f"**Tool Input:** {input_str}") + report_lines.append("") + + # Add observation + obs_str = str(observation) + report_lines.append(f"**Observation:**") + report_lines.append(obs_str) + report_lines.append("") + report_lines.append("---") + report_lines.append("") + + # Add a note about incomplete analysis + report_lines.extend([ + "## Notes", + "", + "- This report was auto-generated because the research agent did not produce a final summary", + "- The assessment phase will use the above tool observations for exploitability analysis", + "- Consider increasing max_iterations if this occurs frequently" + ]) + + return "\n".join(report_lines) + + @retry( + wait=wait_exponential(multiplier=1, min=4, max=60), + stop=stop_after_attempt(5), + retry=retry_if_exception_type((RateLimitError, FunctionTimedOut, KeyError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + after=after_log(logger, logging.INFO) + ) + def _research_node(self, state: AgentState) -> AgentState: + """Research phase: explore code with tools to gather information.""" + logger.info("[Research Phase] Starting code exploration...") + + if not self.tools: + logger.info("[Research Phase] No tools available, skipping research") + return { + "research_notes": ["No exploration tools available - using only prompt context"], + "research_iterations": 0 + } + + # Create research agent with tools + research_prompt = ChatPromptTemplate.from_messages([ + ("system", + "You are assessing the exploitability of a vulnerability sink point. " + "Your goal is to gather information about this sink to determine if it is exploitable.\n\n" + "Use the available tools to explore the code and gather relevant information. 
" + "Focus on: data flow, input validation, accessibility, and exploitability.\n\n" + "After exploring, you MUST provide a comprehensive research report with:\n" + "1. Key findings about data flow (where does the input come from?)\n" + "2. Input validation/sanitization observed (or lack thereof)\n" + "3. Method accessibility (public API? internal only?)\n" + "4. Relevant code snippets that support your findings\n" + "5. Any security controls or mitigations in place\n\n" + "Be thorough and include specific details and code references."), + ("user", "{initial_prompt}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ]) + + research_agent = create_tool_calling_agent(self.llm, self.tools, research_prompt) + research_executor = AgentExecutor( + agent=research_agent, + tools=self.tools, + max_iterations=self.max_research_iterations, + return_intermediate_steps=True, + handle_parsing_errors=True, + verbose=False # Enable to see which tool/LLM call is hanging + ) + + # Run research with timeout + logger.info("[Research Phase] Invoking research executor (timeout=150s)") + result = func_timeout(150, research_executor.invoke, args=({"initial_prompt": state["initial_prompt"]},)) + logger.info("[Research Phase] Research executor completed") + + # Get the agent's final output (the comprehensive report) + agent_output = result.get("output", "") + + # Also capture intermediate steps for reference + intermediate_steps = result.get("intermediate_steps", []) + + # Use the agent's final output as the research note (comprehensive report) + used_fallback = False + if agent_output and "Agent stopped due to max iterations" not in agent_output: + notes = [agent_output] + else: + used_fallback = True + if not agent_output: + reason = "empty output" + elif "Agent stopped due to max iterations" in agent_output: + reason = f"hit max iterations ({self.max_research_iterations})" + else: + reason = "unknown" + logger.warning( + f"{CRS_WARN} Research agent did not generate final report " + f"(reason: {reason}, tool_calls: {len(intermediate_steps)}, " + f"output_length: {len(agent_output)}), creating fallback from tool observations" + ) + if agent_output: + logger.debug(f"[Research Phase] Raw agent output: {agent_output[:500]}") + notes = [self._generate_fallback_report(state["initial_prompt"], intermediate_steps)] + + logger.info( + f"[Research Phase] Completed with {len(intermediate_steps)} tool calls, " + f"report: {len(notes[0])} chars{' (FALLBACK)' if used_fallback else ''}" + ) + + return { + "research_notes": notes, + "research_iterations": len(intermediate_steps), + "intermediate_steps": intermediate_steps, + "used_fallback_report": used_fallback, + } + + @retry( + wait=wait_exponential(multiplier=1, min=4, max=60), + stop=stop_after_attempt(5), + retry=retry_if_exception_type((RateLimitError, FunctionTimedOut, KeyError)), + before_sleep=before_sleep_log(logger, logging.WARNING), + after=after_log(logger, logging.INFO) + ) + def _assessment_node(self, state: AgentState) -> AgentState: + """Assessment phase: use structured output to assess exploitability.""" + logger.info("[Assessment Phase] Making final exploitability assessment...") + + # Build assessment prompt with research findings + research_summary = "\n".join(state.get("research_notes", [])) + + assessment_prompt = ( + f"{state['initial_prompt']}\n\n" + f"## Research Findings\n\n" + f"{research_summary if research_summary else 'No additional research performed.'}\n\n" + f"## Final Assessment\n\n" + f"Based on the code context and research 
findings above, decide whether this sink is likely exploitable.\n" + f"Set unexploitable=True if the sink is very unlikely exploitable.\n" + f"Set unexploitable=False if the sink is possibly exploitable.\n" + f"Provide your assessment in structured format." + ) + + # Get structured assessment with timeout + logger.info("[Assessment Phase] Invoking LLM with structured output (timeout=60s)") + assessment = func_timeout(60, self.llm_with_structure.invoke, args=(assessment_prompt,), kwargs={'tool_choice': 'auto'}) + logger.info("[Assessment Phase] LLM call completed") + self.assessment_result = assessment + + if self.assessment_result is None: + logger.error(f"{CRS_ERR} Assessment phase failed to produce a result") + else: + logger.info(f"[Assessment Phase] Result: {'NOT exploitable' if assessment.unexploitable else 'Potentially exploitable'}") + + return {"assessment": assessment} + + def run(self, query: str) -> Dict: + """Run the two-stage exploitability assessment agent. + + Stage 1: Research phase with exploration tools + Stage 2: Assessment phase with structured output + """ + logger.info("[Agent] Starting agent execution") + + # Initialize state + initial_state = { + "initial_prompt": query, + "research_notes": [], + "research_iterations": 0, + "used_fallback_report": False, + "assessment": None, + "intermediate_steps": [] + } + + logger.info("[Agent] Invoking LangGraph workflow") + # Run the graph + final_state = self.graph.invoke(initial_state) + logger.info("[Agent] LangGraph workflow completed") + + # Store results for logging even if assessment fails + self.intermediate_steps = final_state.get("intermediate_steps", []) + self.research_notes = final_state.get("research_notes", []) + used_fallback = final_state.get("used_fallback_report", False) + + # Validate that assessment was generated + if not self.assessment_result: + logger.error(f"{CRS_ERR} Assessment phase failed to produce a result") + raise RuntimeError("Assessment result is None - agent failed to generate assessment") + + logger.info( + f"[Agent] Agent execution completed successfully" + f"{' (used fallback report)' if used_fallback else ''}" + ) + # Return result in expected format + return { + "output": f"Exploitability: {'NOT exploitable' if self.assessment_result.unexploitable else 'Potentially exploitable'}", + "assessment": self.assessment_result, + "intermediate_steps": self.intermediate_steps, + "research_notes": self.research_notes, + "used_fallback_report": used_fallback, + } + + def get_exploitability_assessment(self) -> bool: + """Get the exploitability assessment (True = NOT exploitable, False = potentially exploitable).""" + if self.assessment_result: + return self.assessment_result.unexploitable + return False # Default to potentially exploitable if no assessment + + def get_assessment_reasoning(self) -> str | None: + """Get the reasoning for the exploitability assessment.""" + if self.assessment_result: + return self.assessment_result.reasoning + return None + + def get_cost_summary(self) -> str: + summary = self.cost_tracker.get_summary() + return (f"Total Cost: ${summary['total_cost']}, " + f"Total Tokens: {summary['total_tokens']} " + f"(Prompt: {summary['total_prompt_tokens']}, " + f"Completion: {summary['total_completion_tokens']}), " + f"LLM Calls: {summary['num_llm_calls']}") diff --git a/crs/filtering-agent/sinkpicker/assessor_prompt.py b/crs/filtering-agent/sinkpicker/assessor_prompt.py new file mode 100644 index 000000000..d3478bec5 --- /dev/null +++ 
b/crs/filtering-agent/sinkpicker/assessor_prompt.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 + +import logging +from pathlib import Path + +from .cpmeta import CPMetadata +from .utils import CRS_ERR_LOG, CRS_WARN_LOG + +logger = logging.getLogger(__name__) + +CRS_ERR = CRS_ERR_LOG("exploitability_prompt") +CRS_WARN = CRS_WARN_LOG("exploitability_prompt") + + +class ExploitabilityPromptGenerator: + """Generates prompts for the exploitability assessment agent.""" + + # CWE descriptions for context + CWE_DESCRIPTIONS = { + "CWE-022": { + "name": "Path Traversal", + "description": "The software uses external input to construct a pathname that is intended to identify a file or directory located underneath a restricted parent directory, but does not properly neutralize special elements within the pathname that can cause it to resolve to a location outside of the restricted directory.", + "exploitability_factors": [ + "Cache files/status files are considered safe", + "Java class/project resources are considered safe", + "Files with a static final path are considered safe", + "If there are functions to normalize paths, check if there is any way around them; try to think of ways to bypass normalization", + "Configuration values are considered safe", + "When extracting files from an archive, it is exploitable if the file path is user-controlled (e.g., zip slip)", + "If it's in a context where the user is intended to be able to provide any file it is also safe (not exploitable due to no threat model; e.g. cli arguments)", + "It must be exploitable from the given harness; consider the context carefully" + ] + }, + "CWE-611": { + "name": "XXE (XML External Entity)", + "description": "The software processes an XML document that can contain XML entities with URIs that resolve to documents outside of the intended sphere of control, causing the product to embed incorrect documents into its output.", + "exploitability_factors": [ + "XML parser without external entity restrictions", + "User-controlled XML input", + "DTD processing enabled", + "Public API that accepts XML data" + ] + }, + "CWE-643": { + "name": "XPath Injection", + "description": "The software uses external input to dynamically construct an XPath expression used to retrieve data from an XML database, but it does not neutralize or incorrectly neutralizes that input.", + "exploitability_factors": [ + "User input directly concatenated into XPath queries", + "Lack of parameterization or input validation", + "XPath expressions used for authentication or authorization", + "Public endpoints accepting XML queries" + ] + }, + "CWE-470": { + "name": "Unsafe Reflection", + "description": "The application uses external input with reflection to select which classes or code to use, but it does not sufficiently prevent the input from selecting improper classes or code.", + "exploitability_factors": [ + "User-controlled class names or method names", + "Reflection API usage (Class.forName, Method.invoke, etc.)", + "Lack of whitelist for allowed classes/methods", + "Public API endpoints that trigger reflection" + ] + }, + "CWE-730": { + "name": "Regex Injection / ReDoS", + "description": "The software uses a regular expression with an inefficient, possibly exponential worst-case computational complexity that consumes excessive amounts of CPU cycles.", + "exploitability_factors": [ + "User-controlled regex patterns", + "Complex regex with nested quantifiers", + "Pattern.compile with user input", + "Public endpoints accepting regex patterns" + ] + }, + 
"CWE-089": { + "name": "SQL Injection", + "description": "The software constructs all or part of an SQL command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended SQL command.", + "exploitability_factors": [ + "String concatenation for SQL queries", + "Lack of parameterized queries or prepared statements", + "User input in WHERE clauses, ORDER BY, or other SQL components", + "Public API endpoints that interact with databases" + ] + }, + "CWE-090": { + "name": "LDAP Injection", + "description": "The software constructs all or part of an LDAP query using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended LDAP query.", + "exploitability_factors": [ + "User input directly concatenated into LDAP queries", + "Lack of input validation or escaping for LDAP special characters", + "LDAP search filters constructed from user input", + "Public API endpoints that perform LDAP queries" + ] + }, + "CWE-094": { + "name": "Code Injection", + "description": "The software constructs all or part of a code segment using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the syntax or behavior of the intended code segment.", + "exploitability_factors": [ + "User input passed to code evaluation functions (eval, script engines, expression language)", + "Bean validation constraints with user-controlled expressions", + "Template engines with user-controlled templates", + "Dynamic code compilation or execution with user input", + "Public API endpoints that process expressions or scripts" + ] + }, + "CWE-117": { + "name": "Log Injection", + "description": "The software does not neutralize or incorrectly neutralizes output that is written to logs. 
This allows an attacker to forge log entries or inject malicious content into logs.",
+            "exploitability_factors": [
+                "User input directly written to logs without sanitization",
+                "Lack of encoding or escaping for log output",
+                "Log entries that could be parsed by log analysis tools",
+                "Potential for log forging or CRLF injection in logs",
+                "Public endpoints where user input appears in logs"
+            ]
+        },
+        "CWE-078": {
+            "name": "OS Command Injection",
+            "description": "The software constructs all or part of an OS command using externally-influenced input from an upstream component, but it does not neutralize or incorrectly neutralizes special elements that could modify the intended OS command.",
+            "exploitability_factors": [
+                "Input from the fuzzing harness passed to Runtime.exec or ProcessBuilder is often exploitable",
+                "Shell command construction via string concatenation",
+                "If static final strings are used, it's likely safe; however, if command strings can be set at runtime, they may be exploitable",
+                "Even just controlling arguments is considered exploitable",
+                "APIs for getting input/output streams, waiting for a process, or getting its exit value are safe; mark them as unexploitable",
+                "Only mark as unexploitable if you are absolutely certain that this cannot be exploited; check the context carefully"
+            ]
+        },
+        "CWE-502": {
+            "name": "Deserialization of Untrusted Data",
+            "description": "The application deserializes untrusted data without sufficiently verifying that the resulting data will be valid.",
+            "exploitability_factors": [
+                "ObjectInputStream.readObject with untrusted data",
+                "User-controlled serialized objects",
+                "Lack of deserialization filters",
+                "Public endpoints accepting serialized data"
+            ]
+        },
+        "CWE-918": {
+            "name": "Server-Side Request Forgery (SSRF)",
+            "description": "The web server receives a URL or similar request from an upstream component and retrieves the contents of this URL, but it does not sufficiently ensure that the request is being sent to the expected destination.",
+            "exploitability_factors": [
+                "User-controlled URLs in HTTP requests",
+                "Internal network access from application with user-controlled input",
+                "Accessing project resources is considered safe",
+                "Closing a resource/disconnecting is never exploitable",
+                "Just because a method is public does not make it exploitable; consider the context, there must be a clear threat given the harness"
+            ]
+        }
+    }
+
+    @staticmethod
+    def get_cwe_description(cwe: str) -> str:
+        """
+        Get the formatted CWE description for a given CWE.
+
+        Args:
+            cwe: CWE identifier (e.g., "CWE-022")
+
+        Returns:
+            Formatted description string with name and description
+        """
+        cwe_info = ExploitabilityPromptGenerator.CWE_DESCRIPTIONS.get(cwe, {
+            "name": cwe,
+            "description": f"Vulnerability type {cwe}",
+            "exploitability_factors": []
+        })
+
+        return f"""## Vulnerability Type
+**{cwe}: {cwe_info['name']}**
+
+{cwe_info['description']}"""
+
+    @staticmethod
+    def get_cwe_exploitability_factors(cwe: str) -> str:
+        """
+        Get the formatted exploitability factors for a given CWE.
+
+        Args:
+            cwe: CWE identifier (e.g., "CWE-022")
+
+        Returns:
+            Formatted exploitability factors as a bulleted list
+        """
+        cwe_info = ExploitabilityPromptGenerator.CWE_DESCRIPTIONS.get(cwe, {
+            "name": cwe,
+            "description": f"Vulnerability type {cwe}",
+            "exploitability_factors": ["User-controlled input reaching the sink"]
+        })
+
+        factors_text = "\n".join([f"- {factor}" for factor in cwe_info["exploitability_factors"]])
+
+        return f"""## Exploitability Factors to Consider
+{factors_text}"""
+
+    def __init__(self, cp_meta: CPMetadata, cwe: str, harness: str = None, call_graph = None):
+        self.cp_meta = cp_meta
+        self.cwe = cwe
+        self.harness = harness
+        self.call_graph = call_graph
+        self.cwe_info = self.CWE_DESCRIPTIONS.get(cwe, {
+            "name": cwe,
+            "description": f"Vulnerability type {cwe}",
+            "exploitability_factors": ["User-controlled input reaching the sink"]
+        })
+
+    def _read_harness_file(self, harness_name: str) -> str | None:
+        """
+        Read the harness file from the oss-fuzz directory.
+
+        Args:
+            harness_name: Name of the harness (e.g., "JenkinsOne")
+
+        Returns:
+            Contents of the harness file, or None if not found
+        """
+        harness_filename = f"{harness_name}.java"
+
+        # Try to find the harness file in the oss-fuzz project directory
+        proj_path = self.cp_meta.get_proj_path()
+        if not proj_path:
+            logger.warning(f"{CRS_WARN} Project path not available in metadata")
+            return None
+
+        # Search for the harness file
+        proj_path_obj = Path(proj_path)
+
+        # Try common locations first before falling back to a recursive search
+        search_paths = [
+            proj_path_obj / harness_filename,  # Root of oss-fuzz project
+            proj_path_obj / "src" / harness_filename,
+            proj_path_obj / "src" / "main" / "java" / harness_filename,
+        ]
+        for candidate in search_paths:
+            if candidate.is_file():
+                logger.info(f"Found harness file: {candidate}")
+                with open(candidate, 'r', encoding='utf-8', errors='ignore') as f:
+                    return f.read()
+
+        # Also search recursively in the project directory
+        try:
+            for candidate in proj_path_obj.rglob(harness_filename):
+                if candidate.is_file():
+                    logger.info(f"Found harness file: {candidate}")
+                    with open(candidate, 'r', encoding='utf-8', errors='ignore') as f:
+                        return f.read()
+        except Exception as e:
+            logger.warning(f"{CRS_WARN} Error searching for harness file: {e}")
+
+        logger.warning(f"{CRS_WARN} Harness file not found: {harness_filename}")
+        return None
+
+    def _find_call_path(self, sink_file: str, sink_line: int) -> str:
+        """
+        Find a call path from fuzzerTestOneInput in the harness to the sink location.
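+
+        The lookup is a breadth-first search over the call graph (see
+        CallGraph.find_path_bfs, which caps exploration at max_depth=20), so the
+        returned path is one shortest call chain, not necessarily the only one.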
+ + Args: + sink_file: Sink file path + sink_line: Sink line number + + Returns: + Formatted call path string, or message if not found + """ + if not self.call_graph or not self.harness: + return "(call graph not available)" + + # Find the fuzzerTestOneInput method in the harness + # The harness class is usually named after the harness file + harness_class = self.harness # May need adjustment based on package structure + start_node = self.call_graph.find_node_by_method(harness_class, "fuzzerTestOneInput") + + if not start_node: + logger.warning(f"{CRS_WARN} Could not find fuzzerTestOneInput in harness {self.harness}") + return f"(fuzzerTestOneInput not found in call graph for harness {self.harness})" + + # Find the node containing the sink location + end_node = self.call_graph.find_node_by_location(sink_file, sink_line) + + if not end_node: + logger.warning(f"{CRS_WARN} Could not find sink location {sink_file}:{sink_line} in call graph") + return f"(sink location {sink_file}:{sink_line} not found in call graph)" + + # Find path using BFS + path = self.call_graph.find_path_bfs(start_node, end_node) + + if not path: + logger.warning(f"{CRS_WARN} No call path found from {self.harness}.fuzzerTestOneInput to {sink_file}:{sink_line}") + return f"(no call path found from fuzzerTestOneInput to sink)" + + # Format the path + return self.call_graph.format_path(path) + + def _get_code_context(self, file_path: str, line_num: int, context_lines: int = 2) -> str: + """ + Get code context around a specific line. + + Args: + file_path: Path to the source file + line_num: Line number to get context for (1-indexed) + context_lines: Number of lines before and after to include + + Returns: + Formatted code snippet with line numbers + """ + # Try to find and read the file + full_path = None + + # Check if it's an absolute path + if Path(file_path).is_absolute(): + if Path(file_path).exists(): + full_path = Path(file_path) + else: + # Try relative to source root + if self.cp_meta.get_repo_src_path(): + candidate = Path(self.cp_meta.get_repo_src_path()) / file_path + if candidate.exists(): + full_path = candidate + + # Try relative to project path + if full_path is None and self.cp_meta.get_proj_path(): + candidate = Path(self.cp_meta.get_proj_path()) / file_path + if candidate.exists(): + full_path = candidate + + if full_path is None: + return f"(Unable to read file: {file_path})" + + try: + with open(full_path, 'r', encoding='utf-8', errors='ignore') as f: + lines = f.readlines() + + # Calculate line range (convert to 0-indexed) + start_line = max(0, line_num - 1 - context_lines) + end_line = min(len(lines), line_num + context_lines) + + # Build formatted output + result = [] + for i in range(start_line, end_line): + line_content = lines[i].rstrip() + # Mark the target line with an arrow + if i == line_num - 1: + result.append(f" > {i+1:4d} | {line_content}") + else: + result.append(f" {i+1:4d} | {line_content}") + + return "\n".join(result) + + except Exception as e: + logger.warning(f"{CRS_WARN} Failed to read file {full_path}: {e}") + return f"(Error reading file: {e})" + + def generate_exploitability_prompt(self, sink: dict) -> str: + """ + Generate a prompt for the agent to assess exploitability of a single sink. + + Args: + sink: Sink dictionary with 'id', 'coord', 'message', 'cwe', etc. 
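+                An illustrative (hypothetical) example:
+                {"id": "sink-42", "cwe": "CWE-022", "message": "...",
+                 "coord": {"file_name": "core/src/Util.java", "line_num": 120,
+                           "start_column": 9, "end_column": 35}}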
+ + Returns: + A formatted prompt string + """ + coord = sink["coord"] + file_path = coord["file_name"] + line_num = coord["line_num"] + start_col = coord.get("start_column", "?") + end_col = coord.get("end_column", "?") + sink_id = sink["id"] + message = sink.get("message", "") + + # Get code context + code_context = self._get_code_context(file_path, line_num, context_lines=5) + + # Build exploitability factors + factors_text = "\n".join([f"- {factor}" for factor in self.cwe_info["exploitability_factors"]]) + + # Add harness information and file contents if available + harness_section = "" + if self.harness: + # Try to find and read the harness file + harness_file_content = self._read_harness_file(self.harness) + + if not harness_file_content: + logger.error(f"{CRS_ERR} Harness file content not found for harness: {self.harness}") + raise ValueError(f"Harness file content not found for harness: {self.harness}") + + # Find call path from fuzzerTestOneInput to sink + call_path = self._find_call_path(file_path, line_num) + + harness_section = f""" +## Test Harness +This sink is associated with the test harness: **{self.harness}** + +**Harness File: {self.harness}.java** +```java +{harness_file_content} +``` + +**One possible Call Path from fuzzerTestOneInput to Sink:** +``` +{call_path} +``` + +Consider the harness context when assessing exploitability. The harness provides insights into: +- What functionality is being tested +- Potential entry points or attack surfaces +- How user input flows into the sink: Only input provided to fuzzerTestOneInput is considered attacker-controlled. +- The call path shows how execution may flow from the harness entry point to this sink + +""" + + prompt = f"""# Exploitability Assessment Task + +## Vulnerability Type +**{self.cwe}: {self.cwe_info['name']}** + +{self.cwe_info['description']} +{harness_section} +## Exploitability Factors to Consider +{factors_text} + +## Your Task +You are assessing the exploitability of a specific vulnerability sink point. + +**Sink Information:** +- **ID:** `{sink_id}` +- **File:** `{file_path}` +- **Location:** Line {line_num}, columns {start_col}-{end_col} +- **Description:** {message} + +**Code Context (±5 lines):** +``` +{code_context} +``` + +## Instructions + +1. **Analyze this specific sink** by considering: + - Can user/external input reach this sink point? (trace back data flow) + - What is the potential impact if exploited? + - Are there any mitigating factors (security managers, sandboxing, etc.)? + +2. **Make an exploitability assessment**: + - Set `unexploitable=True` if the sink is unlikely to be exploitable + - Set `unexploitable=False` if the sink is likely exploitable + +3. 
**Provide a report**:
+   - Justify your assessment with clear reasoning
+   - Reference specific code lines or logic that influenced your decision
+
+## Assessment Criteria
+
+**Mark as unexploitable=True (NOT exploitable) if:**
+- External input does not reach the sink point
+- Input validation/sanitization is present and effective
+- Security controls prevent exploitation
+- You are certain that exploitation is not feasible
+
+**Mark as unexploitable=False (potentially exploitable) if:**
+- External input can reach the sink point
+- Missing or incomplete input validation
+- No effective security controls are in place
+
+## Important Notes
+- Code context (±5 lines) is provided to help with assessment
+- Focus on practical exploitability, not just theoretical vulnerability
+- Use the tools available to explore the codebase if you need more context
+- Mark a location as exploitable if it is likely exploitable from the given harness
+- Only check whether this very specific line is vulnerable; we will handle other lines separately
+
+Begin your assessment now.
+"""
+        return prompt
diff --git a/crs/filtering-agent/sinkpicker/callgraph.py b/crs/filtering-agent/sinkpicker/callgraph.py
new file mode 100644
index 000000000..57f8cf65e
--- /dev/null
+++ b/crs/filtering-agent/sinkpicker/callgraph.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+
+import logging
+from collections import deque
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class CallGraph:
+    """Call graph representation for finding paths between methods."""
+
+    def __init__(self, graph_data: dict):
+        """
+        Initialize call graph from Joern JSON format.
+
+        Args:
+            graph_data: Dict with 'nodes' and 'links' keys
+        """
+        self.nodes_by_id = {}
+        self.adjacency = {}  # node_id -> list of target node_ids
+
+        # Index nodes by ID
+        for node in graph_data.get("nodes", []):
+            node_id = node["id"]
+            self.nodes_by_id[node_id] = node["data"]
+            self.adjacency[node_id] = []
+
+        # Build adjacency list
+        for link in graph_data.get("links", []):
+            source = link["source"]
+            target = link["target"]
+            if source in self.adjacency:
+                self.adjacency[source].append(target)
+
+        logger.info(f"CallGraph initialized with {len(self.nodes_by_id)} nodes")
+
+    def find_node_by_method(self, class_name: str, func_name: str) -> Optional[int]:
+        """
+        Find a node by class name and function name.
+
+        Args:
+            class_name: Class name, matched as a suffix of the node's fully qualified name
+            func_name: Function/method name
+
+        Returns:
+            Node ID if found, None otherwise
+        """
+        for node_id, data in self.nodes_by_id.items():
+            if data.get("class_name", "").endswith("." + class_name) and data.get("func_name") == func_name:
+                return node_id
+        return None
+
+    def find_node_by_location(self, file_name: str, line_num: int) -> Optional[int]:
+        """
+        Find a node that contains the given file and line number.
+ + Args: + file_name: File path (may be partial) + line_num: Line number + + Returns: + Node ID if found, None otherwise + """ + # Normalize file_name for comparison + normalized_file = file_name.replace("\\", "/") + + for node_id, data in self.nodes_by_id.items(): + node_file = data.get("file_name", "").replace("\\", "/") + + # Check if file names match (handle partial paths) + if not (normalized_file in node_file or node_file in normalized_file): + continue + + # Check if line number is within method bounds + start_line = data.get("start_line") + end_line = data.get("end_line") + + if start_line is not None and end_line is not None: + if start_line <= line_num <= end_line: + return node_id + + return None + + def find_path_bfs(self, start_node_id: int, end_node_id: int, max_depth: int = 20) -> Optional[list[int]]: + """ + Find a path from start_node to end_node using BFS. + + Args: + start_node_id: Starting node ID + end_node_id: Target node ID + max_depth: Maximum search depth + + Returns: + List of node IDs forming the path, or None if no path found + """ + if start_node_id not in self.nodes_by_id or end_node_id not in self.nodes_by_id: + return None + + if start_node_id == end_node_id: + return [start_node_id] + + # BFS with path tracking + queue = deque([(start_node_id, [start_node_id])]) + visited = {start_node_id} + + while queue: + current_id, path = queue.popleft() + + # Check depth limit + if len(path) > max_depth: + continue + + # Explore neighbors + for neighbor_id in self.adjacency.get(current_id, []): + if neighbor_id == end_node_id: + return path + [neighbor_id] + + if neighbor_id not in visited: + visited.add(neighbor_id) + queue.append((neighbor_id, path + [neighbor_id])) + + return None + + def format_path(self, path: list[int]) -> str: + """ + Format a path as a human-readable string. 
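+
+        Output shape (one entry per hop; class and file names are illustrative):
+            [1] com.example.Harness.fuzzerTestOneInput()
+                at Harness.java:12
+            [2] com.example.PathUtil.resolve()
+                at PathUtil.java:87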
+ + Args: + path: List of node IDs + + Returns: + Formatted call path string + """ + if not path: + return "(no path found)" + + lines = [] + for i, node_id in enumerate(path): + data = self.nodes_by_id.get(node_id, {}) + class_name = data.get("class_name", "?") + func_name = data.get("func_name", "?") + file_name = data.get("file_name", "?") + start_line = data.get("start_line", "?") + + # Format: [1] ClassName.methodName() at file.java:line + lines.append(f" [{i+1}] {class_name}.{func_name}()") + lines.append(f" at {file_name}:{start_line}") + + return "\n".join(lines) + + def get_node_info(self, node_id: int) -> dict: + """Get node data by ID.""" + return self.nodes_by_id.get(node_id, {}) diff --git a/crs/filtering-agent/sinkpicker/cpmeta.py b/crs/filtering-agent/sinkpicker/cpmeta.py new file mode 100755 index 000000000..c0c75de9d --- /dev/null +++ b/crs/filtering-agent/sinkpicker/cpmeta.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +import json +import logging +from pathlib import Path +from typing import Any, Dict + +from .utils import CRS_ERR_LOG, CRS_WARN_LOG + +logger = logging.getLogger(__name__) + + +CRS_ERR = CRS_ERR_LOG("cpmeta") +CRS_WARN = CRS_WARN_LOG("cpmeta") + + +class CPMetadata: + """Class for handling CP metadata from a JSON file.""" + + def __init__(self, json_file: str): + """Initialize with the path to the metadata JSON file.""" + self.json_file = Path(json_file) + self.metadata: Dict[str, Any] = self._load_metadata() + + def _load_metadata(self) -> Dict[str, Any]: + """Load and validate the metadata JSON file.""" + if not self.json_file.exists(): + raise FileNotFoundError(f"Metadata file not found: {self.json_file}") + + try: + with open(self.json_file) as f: + metadata = json.load(f) + + if not isinstance(metadata, dict): + raise ValueError("Metadata must be a JSON object") + + return metadata + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in metadata file: {e}") + + def get_cp_name(self) -> str: + """Get the CP name from the metadata.""" + return self.metadata["cp_name"] + + def get_target_class(self, target_harness: str) -> str | None: + """Get the target class for the harness.""" + for info in self.metadata.get("harnesses", {}).values(): + if info.get("name") == target_harness: + return info.get("target_class") + + def get_classpath(self, target_harness: str) -> str | None: + """Get the classpath for the harness.""" + for info in self.metadata.get("harnesses", {}).values(): + if info.get("name") == target_harness: + return ":".join(info.get("classpath", [])) + + def get_custom_sink_conf_path(self) -> Path | None: + """Get the path to the custom sink configuration file.""" + path = self.metadata.get("custom_sink_conf_path", None) + return Path(path) if path else None + + def resolve_file_path(self, class_name: str, file_name: str): + if "." in class_name: + # fully qualified class_name -> pkg_name + # com.example.foo.Bar -> com.example.foo + pkg_name = ".".join(class_name.split(".")[:-1]) + pkg_path_part = pkg_name.replace(".", "/") + else: + return None + + pkg_file_paths = [ + Path(f) for f in self.metadata.get("pkg2files", {}).get(pkg_name, []) + ] + matching_files = [ + f for f in pkg_file_paths if f.name == file_name and pkg_path_part in str(f) + ] + # Multiple matches? 
Warn and use the longest one + if len(matching_files) > 1: + logger.warning( + f"{CRS_WARN} Multiple file matches for {class_name}.{file_name}: {matching_files}" + ) + return str(max(matching_files, key=lambda p: len(str(p)))) + elif len(matching_files) == 1: + return str(matching_files[0]) + else: + return None + + def resolve_frame_to_file_path(self, frame) -> str | None: + if not frame or "class_name" not in frame or "file_name" not in frame: + return None + return self.resolve_file_path(frame["class_name"], frame["file_name"]) + + def get_proj_path(self) -> Path | None: + """Get the OSS-Fuzz project directory path from metadata. + + Returns: + Path to the OSS-Fuzz project directory (harnesses, build scripts, etc.) + """ + proj_path = self.metadata.get("proj_path", None) + return Path(proj_path) if proj_path else None + + def get_repo_src_path(self) -> Path | None: + """Get the challenge project source code directory path from metadata. + + Returns: + Path to the challenge project source code directory + """ + repo_src_path = self.metadata.get("repo_src_path", None) + return Path(repo_src_path) if repo_src_path else None diff --git a/crs/filtering-agent/sinkpicker/pick_sinks.py b/crs/filtering-agent/sinkpicker/pick_sinks.py new file mode 100644 index 000000000..4214001f2 --- /dev/null +++ b/crs/filtering-agent/sinkpicker/pick_sinks.py @@ -0,0 +1,737 @@ +#!/usr/bin/env python3 + +import argparse +import json +import logging +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from .cpmeta import CPMetadata +from .assessor_agent import ExploitabilityAssessor +from .assessor_prompt import ExploitabilityPromptGenerator +from .callgraph import CallGraph +from .utils import CRS_ERR_LOG, CRS_WARN_LOG + +CRS_ERR = CRS_ERR_LOG("pick_sinks") +CRS_WARN = CRS_WARN_LOG("pick_sinks") + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - [%(threadName)s] - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +try: + from libCRS.otel import install_otel_logger + install_otel_logger(action_name="crs-java:filtering-agent") +except Exception as e: + print(f"{CRS_ERR} Failed to install OpenTelemetry logger: {e}.") + +# Force reconfigure all handlers to include thread name +# (in case OpenTelemetry or other libraries changed the format) +formatter = logging.Formatter("%(asctime)s - [%(threadName)s] - %(levelname)s - %(message)s") +for handler in logging.root.handlers: + handler.setFormatter(formatter) + + +def _run_agent_assessment( + agent: ExploitabilityAssessor, + prompt: str, + sink_id: str, +) -> tuple[bool, str, dict, bool]: + """ + Run agent assessment and return results. 
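+
+    If the assessment phase raises RuntimeError, the function degrades gracefully:
+    it returns unexploitable=False (the conservative choice, keeping the sink in the
+    final set) together with the research notes gathered so far and
+    has_assessment_error=True.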
+ + Returns: + tuple of (unexploitable, reasoning, result_dict, has_assessment_error) + """ + logger.info(f"[Thread] Running exploitability assessment for {sink_id}...") + + try: + result = agent.run(prompt) + result["_initial_prompt"] = prompt + + unexploitable = agent.get_exploitability_assessment() + reasoning = agent.get_assessment_reasoning() + + # Check if assessment was actually generated + if unexploitable is None or reasoning is None: + logger.error(f"{CRS_ERR} Assessment result incomplete for {sink_id}") + return False, "Assessment generation failed", result, True + + logger.info(f"[Thread] Exploitability assessment for {sink_id}: {'NOT exploitable' if unexploitable else 'potentially exploitable'}") + logger.info(f"[Thread] Reasoning: {reasoning}") + logger.info(f"[Thread] Cost: {agent.get_cost_summary()}") + + return unexploitable, reasoning, result, False + + except RuntimeError as e: + # Assessment generation failed — preserve research notes from the agent + logger.error(f"{CRS_ERR} Assessment generation failed for {sink_id}: {e}") + return False, str(e), { + "_initial_prompt": prompt, + "research_notes": getattr(agent, "research_notes", []), + "output": "", + }, True + + +def _save_agent_log( + workdir_path: Path, + sink_idx: int, + sink_id: str, + file_path: str, + line_num: int, + cwe: str, + max_iterations: int, + result: dict, + unexploitable: bool, + reasoning: str, + agent: ExploitabilityAssessor, + used_fallback_report: bool = False, +) -> Path: + """Save agent execution log to workdir.""" + safe_sink_id = sink_id.replace("/", "_").replace("\\", "_").replace(":", "_") + steps_file = workdir_path / f"agent_steps_{sink_idx}_{safe_sink_id}.txt" + + with open(steps_file, "w") as f: + # Header + f.write("="*80 + "\n") + f.write("EXPLOITABILITY ASSESSMENT LOG\n") + f.write("="*80 + "\n\n") + f.write(f"Sink ID: {sink_id}\n") + f.write(f"File: {file_path}\n") + f.write(f"Line: {line_num}\n") + f.write(f"Sink Index: {sink_idx}\n") + f.write(f"CWE: {cwe}\n\n") + + # Initial prompt + f.write("="*80 + "\n") + f.write("INITIAL PROMPT\n") + f.write("="*80 + "\n\n") + initial_prompt = result.get("_initial_prompt", "(prompt not available)") + f.write(initial_prompt) + f.write("\n\n") + + # Research findings (comprehensive report) + research_notes = result.get("research_notes", []) + f.write("="*80 + "\n") + if used_fallback_report: + f.write("RESEARCH REPORT (FALLBACK - agent did not generate final report)\n") + else: + f.write("RESEARCH REPORT\n") + f.write("="*80 + "\n\n") + if research_notes: + for note in research_notes: + f.write(note) + f.write("\n\n") + else: + f.write("No research report available.\n\n") + + # Final assessment + f.write("="*80 + "\n") + f.write("FINAL ASSESSMENT\n") + f.write("="*80 + "\n\n") + f.write(f"Exploitability: {'NOT exploitable (unexploitable=True)' if unexploitable else 'Potentially exploitable (unexploitable=False)'}\n\n") + f.write(f"Reasoning:\n{reasoning if reasoning else 'No reasoning provided'}\n\n") + f.write(f"Agent Output:\n{result.get('output', '(no output)')}\n\n") + + # Cost summary + cost_summary = agent.cost_tracker.get_summary() + f.write("="*80 + "\n") + f.write("COST SUMMARY\n") + f.write("="*80 + "\n\n") + f.write(f"Total Cost: ${cost_summary['total_cost']:.6f}\n") + f.write(f"Total Tokens: {cost_summary['total_tokens']}\n") + f.write(f"Prompt Tokens: {cost_summary['total_prompt_tokens']}\n") + f.write(f"Completion Tokens: {cost_summary['total_completion_tokens']}\n") + f.write(f"Number of LLM Calls: 
{cost_summary['num_llm_calls']}\n\n") + + if cost_summary.get('cost_breakdown'): + f.write("Per-call Breakdown:\n") + for idx, call in enumerate(cost_summary['cost_breakdown'], 1): + f.write(f" Call {idx}: {call['total_tokens']} tokens, ${call['cost']:.6f}\n") + + f.write("\n" + "="*80 + "\n") + + logger.info(f"[Thread] Saved agent steps to {steps_file}") + + # Dump raw LLM interactions for debugging + llm_log_file = workdir_path / f"llm_calls_{sink_idx}_{safe_sink_id}.txt" + agent.interaction_logger.dump(llm_log_file) + logger.info(f"[Thread] Saved LLM interactions ({len(agent.interaction_logger.interactions)} calls) to {llm_log_file}") + + return steps_file + + +def process_single_sink( + sink_idx: int, + sink: dict, + gen_model: str, + temperature: float | None, + cp_meta: CPMetadata, + workdir_path: Path, + max_iterations: int, + prompt_gen: ExploitabilityPromptGenerator, + cwe: str, + total_sinks: int, +) -> dict: + """ + Process a single sink using an agent to assess exploitability. + + Returns a dict with: + - success: bool + - sink_id: str + - unexploitable: bool (True if NOT exploitable, False if potentially exploitable) + - reasoning: str + - cost_summary: dict + - sink_index: int + - steps_file: Path + - error: str (if failed) + """ + file_path = sink["coord"]["file_name"] + sink_id = sink["id"] + line_num = sink["coord"]["line_num"] + + logger.info(f"{'='*80}") + logger.info(f"Analyzing sink #{sink_idx}: {sink_id}") + logger.info(f"Location: {file_path}:{line_num}") + + try: + # Create agent and generate prompt + agent = ExploitabilityAssessor( + model=gen_model, + temperature=temperature, + cp_meta=cp_meta, + work_dir=workdir_path, + max_iterations=max_iterations, + ) + prompt = prompt_gen.generate_exploitability_prompt(sink) + + # Run assessment + unexploitable, reasoning, result, has_assessment_error = _run_agent_assessment(agent, prompt, sink_id) + used_fallback = result.get("used_fallback_report", False) + + # Save execution log + steps_file = _save_agent_log( + workdir_path, sink_idx, sink_id, file_path, line_num, + cwe, max_iterations, result, unexploitable, reasoning, agent, + used_fallback_report=used_fallback, + ) + + # Prepare cost summary + cost_summary = agent.cost_tracker.get_summary() + cost_summary["sink_id"] = sink_id + cost_summary["sink_index"] = sink_idx + cost_summary["unexploitable"] = unexploitable + + if used_fallback: + logger.warning(f"[Thread] Sink {sink_id}: used fallback report (agent did not generate research report)") + + logger.info(f"Completed sink #{sink_idx}: {sink_id}") + return { + "success": True, + "sink_id": sink_id, + "unexploitable": unexploitable, + "reasoning": reasoning, + "cost_summary": cost_summary, + "sink_index": sink_idx, + "steps_file": steps_file, + "used_fallback_report": used_fallback, + "has_assessment_error": has_assessment_error, + } + + except Exception as e: + logger.error(f"{CRS_ERR} Error processing sink {sink_id}: {e}") + logger.error(traceback.format_exc()) + return { + "success": False, + "sink_id": sink_id, + "sink_index": sink_idx, + "error": str(e), + } + + +def run_for_single_pair( + sinks: list[dict], + harness: str, + cwe: str, + gen_model: str, + temperature: float | None, + cp_meta: CPMetadata, + call_graph: CallGraph, + workdir_path: Path, + max_iterations: int, + max_workers: int, + force_agent_analysis: bool = False, +) -> tuple[dict, dict, dict]: + """ + Run exploitability assessment for a single harness:CWE pair. 
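+
+    Sink-set selection before any agent work (mirrors the Set A / Set B logic below):
+    Set A is the CWE's sinks minus flow/test-filtered ones; if |Set A| <= 10 it is used
+    directly. Otherwise Set B is the reachable subset of Set A, used unless it is empty
+    (then Set A is kept). Agent analysis runs only when the chosen set has more than 10
+    sinks, or when force_agent_analysis is set.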
+
+    Returns:
+        tuple of (unexploitable_harnesses, in_final_result_harnesses, cost_overview)
+        where the first two dicts map sink_id -> harness (for sinks that are unexploitable/in_final_result for this harness)
+    """
+    logger.info(f"{'='*80}")
+    logger.info(f"Processing pair: {harness}:{cwe}")
+    logger.info(f"{'='*80}")
+
+    # === Determine input sink set ===
+    # Step 1: Filter by CWE and existing filter flags (Set A). Sink entries are
+    # not harness-specific; the harness scoping comes from the pair itself.
+    set_a = [
+        s for s in sinks
+        if s.get("cwe") == cwe
+        and not s.get("filtered_out_flow", False)
+        and not s.get("filtered_out_test", False)
+    ]
+    logger.info(f"Set A (harness={harness}, CWE={cwe}, not filtered): {len(set_a)} sinks")
+
+    if not set_a:
+        logger.warning(f"{CRS_WARN} No sinks found for {harness}:{cwe} after filtering.")
+        return {}, {}, None  # No sinks for this harness
+
+    # Step 2: Check if Set A has <= 10 sinks
+    if len(set_a) <= 10:
+        input_sink_set = set_a
+        logger.info(f"Set A has <= 10 sinks. Using Set A as input sink set.")
+    else:
+        # Step 3: Filter by reachability (Set B)
+        set_b = [s for s in set_a if s.get("reachable", False)]
+        logger.info(f"Set B (reachable from Set A): {len(set_b)} sinks")
+
+        # Step 4: Determine input sink set based on Set B size
+        if len(set_b) == 0:
+            input_sink_set = set_a
+            logger.info(f"Set B is empty. Using Set A as input sink set.")
+        elif len(set_b) <= 10:
+            input_sink_set = set_b
+            logger.info(f"Set B has <= 10 sinks. Using Set B as input sink set.")
+        else:
+            # Set B has > 10 sinks, need to run agent analysis
+            input_sink_set = set_b
+            logger.info(f"Set B has > 10 sinks. Will analyze all {len(set_b)} sinks.")
+
+    logger.info(f"{'='*80}")
+    logger.info(f"Input sink set determined: {len(input_sink_set)} sinks to analyze")
+
+    # === Run Exploitability Analysis ===
+    # If input sink set has <= 10 sinks and not forcing analysis, skip agent analysis
+    if len(input_sink_set) <= 10 and not force_agent_analysis:
+        logger.info(f"Input sink set has <= 10 sinks. Skipping agent analysis.")
+        # All sinks in input_sink_set are in final result for this harness (no agent analysis)
+        in_final_result_harnesses = {s["id"]: harness for s in input_sink_set}
+        return {}, in_final_result_harnesses, None
+
+    # If forcing analysis or input sink set has > 10 sinks, run agent analysis
+    if force_agent_analysis and len(input_sink_set) <= 10:
+        logger.info(f"Input sink set has <= 10 sinks, but --force-agent-analysis is enabled. 
Running agent analysis.") + + # Create workdir for agent analysis artifacts + workdir_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Using pair workdir: {workdir_path}") + + # Create prompt generator + prompt_gen = ExploitabilityPromptGenerator(cp_meta, cwe, harness, call_graph) + + # Track which sinks are unexploitable/in_final_result for this harness + unexploitable_harnesses = {} # sink_id -> harness (for sinks that are unexploitable) + in_final_result_harnesses = {} # sink_id -> harness (for sinks that are in_final_result) + exploitability_by_sink_id = {} # sink_id -> unexploitable (bool) - for tracking stats + all_cost_summaries = [] + + # Error tracking + no_report_count = 0 # Missing research reports + no_assessment_count = 0 # Missing assessment results + total_errors = 0 # Total failed sinks + + # Process sinks in parallel + total_sinks = len(input_sink_set) + logger.info(f"{'='*80}") + logger.info(f"Running exploitability assessment on {total_sinks} sinks in parallel...") + + # Prepare sink processing tasks + sink_tasks = [ + (sink_idx, sink) + for sink_idx, sink in enumerate(input_sink_set, 1) + ] + + # Use ThreadPoolExecutor to process sinks in parallel + # Using threads since the work is I/O-bound (API calls) + actual_workers = min(total_sinks, max_workers) # Don't create more workers than sinks + logger.info(f"Using {actual_workers} parallel workers") + + with ThreadPoolExecutor(max_workers=actual_workers) as executor: + # Submit all tasks + future_to_sink = { + executor.submit( + process_single_sink, + sink_idx, + sink, + gen_model, + temperature, + cp_meta, + workdir_path, + max_iterations, + prompt_gen, + cwe, + total_sinks, + ): (sink_idx, sink["id"]) + for sink_idx, sink in sink_tasks + } + + # Process completed tasks as they finish + count = 1 + for future in as_completed(future_to_sink): + sink_idx, sink_id = future_to_sink[future] + try: + result = future.result() + + # By default, assume sink is in final result (for failures) + in_final_result_harnesses[sink_id] = harness + + if result["success"]: + # Track exploitability assessment + unexploitable = result["unexploitable"] + exploitability_by_sink_id[sink_id] = unexploitable + + # Update harness tracking based on exploitability + if unexploitable: + # Sink is unexploitable for this harness + unexploitable_harnesses[sink_id] = harness + # Remove from in_final_result for this harness + if sink_id in in_final_result_harnesses: + del in_final_result_harnesses[sink_id] + else: + # Sink is potentially exploitable, keep in final result + in_final_result_harnesses[sink_id] = harness + + # Collect cost summary + all_cost_summaries.append(result["cost_summary"]) + + # Track errors + if result.get("used_fallback_report", False): + no_report_count += 1 + logger.warning(f"{CRS_WARN} Used fallback research report for sink {sink_id}") + + if result.get("has_assessment_error", False): + no_assessment_count += 1 + logger.error(f"{CRS_ERR} Missing assessment result for sink {sink_id}") + + logger.info(f"Completed sink #{sink_idx} ({count}/{total_sinks}): {sink_id} - {'NOT exploitable' if unexploitable else 'Potentially exploitable'}") + else: + total_errors += 1 + # On complete failure, keep in final result (conservative) + logger.error(f"Failed to process sink #{sink_idx} ({count}/{total_sinks}): {sink_id} - Error: {result.get('error')}") + + except Exception as e: + logger.error(f"{CRS_ERR} Exception getting result for {sink_id}: {e}") + logger.error(traceback.format_exc()) + + count += 1 + + logger.info(f"{'='*80}") + 
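    # Decision recap for each completed future above (hypothetical sink ids):
+    #   "s1" judged unexploitable -> recorded in unexploitable_harnesses, dropped from in_final_result_harnesses
+    #   "s2" judged exploitable   -> kept in in_final_result_harnesses
+    #   "s3" failed (success=False) -> kept in in_final_result_harnesses (the conservative default set above) +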
logger.info(f"Completed processing all {total_sinks} sinks in parallel") + + # Report on errors + if total_errors > 0 or no_report_count > 0 or no_assessment_count > 0: + logger.warning(f"{'='*80}") + logger.warning(f"ERROR SUMMARY") + logger.warning(f"{'='*80}") + if total_errors > 0: + logger.error(f"{CRS_ERR} Complete failures (exceptions): {total_errors}/{total_sinks}") + if no_report_count > 0: + logger.warning(f"{CRS_WARN} Fallback research reports: {no_report_count}/{total_sinks}") + if no_assessment_count > 0: + logger.error(f"{CRS_ERR} Missing assessment results: {no_assessment_count}/{total_sinks}") + logger.warning(f"{'='*80}\n") + else: + logger.info(f"All {total_sinks} sinks processed successfully with reports and assessments") + + # Calculate cost summary + total_cost = sum(s["total_cost"] for s in all_cost_summaries) + total_tokens = sum(s["total_tokens"] for s in all_cost_summaries) + total_prompt_tokens = sum(s["total_prompt_tokens"] for s in all_cost_summaries) + total_completion_tokens = sum(s["total_completion_tokens"] for s in all_cost_summaries) + total_llm_calls = sum(s["num_llm_calls"] for s in all_cost_summaries) + + # Count exploitable vs unexploitable + exploitable_count = sum(1 for v in exploitability_by_sink_id.values() if not v) + unexploitable_count = sum(1 for v in exploitability_by_sink_id.values() if v) + + cost_overview = { + "harness": harness, + "cwe": cwe, + "pair": f"{harness}:{cwe}", + "total_sinks_analyzed": len(exploitability_by_sink_id), + "exploitable_sinks": exploitable_count, + "unexploitable_sinks": unexploitable_count, + "errors": { + "total_failures": total_errors, + "fallback_reports": no_report_count, + "missing_assessments": no_assessment_count, + }, + "overall_cost": round(total_cost, 6), + "overall_tokens": total_tokens, + "overall_prompt_tokens": total_prompt_tokens, + "overall_completion_tokens": total_completion_tokens, + "overall_llm_calls": total_llm_calls, + "per_sink_costs": all_cost_summaries, + } + + # Save CWE-specific cost overview + cost_file = workdir_path / "cost_overview.json" + with open(cost_file, "w") as f: + json.dump(cost_overview, f, indent=2) + logger.info(f"{'='*80}") + logger.info(f"Saved cost overview to {cost_file}") + logger.info(f"Total Cost: ${cost_overview['overall_cost']}") + logger.info(f"Total Tokens: {cost_overview['overall_tokens']}") + logger.info(f"Total LLM Calls: {cost_overview['overall_llm_calls']}") + logger.info(f"Exploitable: {exploitable_count}, Unexploitable: {unexploitable_count}") + + # Log error summary + if total_errors > 0 or no_report_count > 0 or no_assessment_count > 0: + logger.info(f"Errors: {total_errors} failures, {no_report_count} fallback reports, {no_assessment_count} missing assessments") + + return unexploitable_harnesses, in_final_result_harnesses, cost_overview + + +def run( + input_path: str, + output_path: str, + metadata_path: str, + harness_cwe_pairs: str, + call_graph_path: str, + workdir: str = None, + gen_model: str = "gpt-5", + max_iterations: int = 15, + temperature: float | None = None, + verbose: bool = False, + max_workers: int = 10, + force_agent_analysis: bool = False, +): + """Run exploitability assessment on sink candidates.""" + + try: + # Load input JSON + logger.info(f"Loading input from {input_path}") + with open(input_path, "r") as f: + sinks = json.load(f) + + logger.info(f"Loaded {len(sinks)} total sink locations") + + # Load call graph + logger.info(f"Loading call graph from {call_graph_path}") + with open(call_graph_path, "r") as f: + 
call_graph_data = json.load(f)
+        logger.info(f"Loaded call graph with {len(call_graph_data.get('nodes', []))} nodes and {len(call_graph_data.get('links', []))} edges")
+
+        # Create CallGraph instance
+        call_graph = CallGraph(call_graph_data)
+
+        # Parse harness:CWE pairs (comma-separated list)
+        pairs = []
+        for pair_str in harness_cwe_pairs.split(","):
+            pair_str = pair_str.strip()
+            if not pair_str:
+                continue
+            if ":" not in pair_str:
+                logger.error(f"{CRS_ERR} Invalid harness:CWE pair format: {pair_str}")
+                raise ValueError(f"Invalid pair format: {pair_str}. Expected 'harness:CWE'")
+            harness, cwe = pair_str.split(":", 1)
+            pairs.append((harness.strip(), cwe.strip()))
+
+        logger.info(f"Analyzing {len(pairs)} harness:CWE pair(s): {', '.join(f'{h}:{c}' for h, c in pairs)}")
+
+        logger.info(f"Generation model: {gen_model}")
+
+        # Load metadata
+        cp_meta = CPMetadata(metadata_path)
+
+        # Prepare base workdir
+        base_workdir_path = Path(workdir) if workdir else Path.cwd() / "exploitability_workdir"
+        base_workdir_path.mkdir(parents=True, exist_ok=True)
+        logger.info(f"Using base workdir: {base_workdir_path}")
+
+        # Track all exploitability assessments and final result flags across all pairs
+        all_exploitability_by_sink_id = {}  # sink_id -> list of harnesses judged unexploitable
+        all_in_final_result_by_sink_id = {}  # sink_id -> list of harnesses kept in the final result
+        all_cost_overviews = []
+
+        # Process each harness:CWE pair independently
+        for harness, cwe in pairs:
+            logger.info(f"{'='*80}")
+            logger.info(f"Starting analysis for {harness}:{cwe}")
+            logger.info(f"{'='*80}")
+
+            # Prepare pair-specific workdir path (will be created only if needed)
+            pair_workdir_path = base_workdir_path / f"{harness}_{cwe}"
+
+            # Run analysis for this pair
+            unexploitable_harnesses, in_final_result_harnesses, cost_overview = run_for_single_pair(
+                sinks,
+                harness,
+                cwe,
+                gen_model,
+                temperature,
+                cp_meta,
+                call_graph,
+                pair_workdir_path,
+                max_iterations,
+                max_workers,
+                force_agent_analysis,
+            )
+
+            # Merge results - aggregate harnesses into lists
+            for sink_id, harness_name in unexploitable_harnesses.items():
+                if sink_id not in all_exploitability_by_sink_id:
+                    all_exploitability_by_sink_id[sink_id] = []
+                all_exploitability_by_sink_id[sink_id].append(harness_name)
+
+            for sink_id, harness_name in in_final_result_harnesses.items():
+                if sink_id not in all_in_final_result_by_sink_id:
+                    all_in_final_result_by_sink_id[sink_id] = []
+                all_in_final_result_by_sink_id[sink_id].append(harness_name)
+            if cost_overview:
+                all_cost_overviews.append(cost_overview)
+
+            logger.info(f"{'='*80}")
+            logger.info(f"Completed analysis for {harness}:{cwe}")
+            logger.info(f"{'='*80}")
+
+        # Create combined cost overview
+        if all_cost_overviews:
+            combined_cost = sum(c["overall_cost"] for c in all_cost_overviews)
+            combined_tokens = sum(c["overall_tokens"] for c in all_cost_overviews)
+            combined_prompt_tokens = sum(c["overall_prompt_tokens"] for c in all_cost_overviews)
+            combined_completion_tokens = sum(c["overall_completion_tokens"] for c in all_cost_overviews)
+            combined_llm_calls = sum(c["overall_llm_calls"] for c in all_cost_overviews)
+            combined_sinks_analyzed = sum(c["total_sinks_analyzed"] for c in all_cost_overviews)
+            combined_exploitable = sum(c["exploitable_sinks"] for c in all_cost_overviews)
+            combined_unexploitable = sum(c["unexploitable_sinks"] for c in all_cost_overviews)
+
+            combined_overview = {
+                "pairs_analyzed": [f"{h}:{c}" for h, c in pairs],
+                "total_sinks_analyzed": combined_sinks_analyzed,
+                "exploitable_sinks": combined_exploitable,
+                "unexploitable_sinks": combined_unexploitable,
+                "overall_cost": round(combined_cost, 6),
+                "overall_tokens": combined_tokens,
+                "overall_prompt_tokens": combined_prompt_tokens,
+                "overall_completion_tokens": combined_completion_tokens,
+                "overall_llm_calls": combined_llm_calls,
+                "per_pair_costs": all_cost_overviews,
+            }
+
+            # Save combined cost overview
+            combined_cost_file = base_workdir_path / "cost_overview.json"
+            with open(combined_cost_file, "w") as f:
+                json.dump(combined_overview, f, indent=2)
+            logger.info(f"{'='*80}")
+            logger.info(f"Saved combined cost overview to {combined_cost_file}")
+            logger.info(f"Total Cost: ${combined_overview['overall_cost']}")
+            logger.info(f"Total Tokens: {combined_overview['overall_tokens']}")
+            logger.info(f"Total LLM Calls: {combined_overview['overall_llm_calls']}")
+            logger.info(f"Total Sinks Analyzed: {combined_sinks_analyzed}")
+            logger.info(f"Exploitable: {combined_exploitable}, Unexploitable: {combined_unexploitable}")
+
+        # Update all sinks with unexploitable and in_final_result fields
+        logger.info(f"{'='*80}")
+        logger.info(f"Updating sink entries with exploitability assessments...")
+
+        for sink in sinks:
+            sink_id = sink["id"]
+            # "unexploitable": list of harnesses for which the agent judged the sink unexploitable
+            if sink_id in all_exploitability_by_sink_id:
+                sink["unexploitable"] = all_exploitability_by_sink_id[sink_id]
+            # "in_final_result": list of harnesses for which the sink stays in the final result
+            if sink_id in all_in_final_result_by_sink_id:
+                sink["in_final_result"] = all_in_final_result_by_sink_id[sink_id]
+
+        # Write output
+        logger.info(f"Writing output to {output_path}")
+        with open(output_path, "w") as f:
+            json.dump(sinks, f, indent=2)
+
+        logger.info("Exploitability assessment completed successfully")
+
+    except Exception as e:
+        logger.error(f"{CRS_ERR} Fatal error: {e}")
+        logger.error(traceback.format_exc())
+        raise
+
+
+def main():
+    """Main entry point for the exploitability assessment tool."""
+    parser = argparse.ArgumentParser(
+        description="AI-powered exploitability assessment tool for vulnerability sink points"
+    )
+    parser.add_argument("input", help="Path to input JSON file with sink candidates")
+    parser.add_argument("output", help="Path to save output JSON file")
+    parser.add_argument(
+        "--metadata", required=True, help="Path to the CP metadata JSON file"
+    )
+    parser.add_argument(
+        "--harness-cwe-pairs", required=True, help="Comma-separated list of harness:CWE pairs (e.g., 'harness1:CWE-022,harness2:CWE-089,harness3:CWE-022'). Each pair is analyzed independently."
+    )
+    parser.add_argument(
+        "--call-graph", required=True, help="Path to call graph JSON file (Joern format)"
+    )
+    parser.add_argument(
+        "--workdir", default=None, help="Working directory for agent artifacts"
+    )
+    parser.add_argument(
+        "--gen-model",
+        default="gpt-5",
+        help="LLM model used by the filtering agent. Default: gpt-5",
+    )
+    parser.add_argument(
+        "--max-iterations",
+        type=int,
+        default=15,
+        help="Maximum number of agent iterations. Default: 15",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=None,
+        help="LLM temperature for agent. If not set, uses the model's default.",
+    )
+    parser.add_argument(
+        "--max-workers",
+        type=int,
+        default=10,
+        help="Maximum number of parallel workers for processing sinks. Default: 10",
+    )
+    parser.add_argument(
+        "--force-agent-analysis",
+        action="store_true",
+        help="Force agent analysis even when input sink set has ≤10 sinks. 
Default: False (skip analysis for ≤10 sinks)", + ) + parser.add_argument( + "--verbose", action="store_true", help="Enable verbose logging" + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + run( + args.input, + args.output, + args.metadata, + args.harness_cwe_pairs, + args.call_graph, + args.workdir, + args.gen_model, + args.max_iterations, + args.temperature, + args.verbose, + args.max_workers, + args.force_agent_analysis, + ) + + +if __name__ == "__main__": + main() diff --git a/crs/filtering-agent/sinkpicker/utils.py b/crs/filtering-agent/sinkpicker/utils.py new file mode 100755 index 000000000..261102001 --- /dev/null +++ b/crs/filtering-agent/sinkpicker/utils.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 + +import logging +import os +import shlex +import sys + +logger = logging.getLogger(__name__) + +SENSITIVE_ENV_VARS = [ + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_DEPLOYMENT", + "LITELLM_KEY", + "GITHUB_TOKEN", + "GITHUB_USER", +] + + +def CRS_ERR_LOG(mod: str) -> str: + return f"CRS-JAVA-ERR-sinkpicker-{mod}" + + +def CRS_WARN_LOG(mod: str) -> str: + return f"CRS-JAVA-WARN-sinkpicker-{mod}" + + +def get_env_or_abort(env_name: str) -> str: + env_value = os.getenv(env_name) + if env_value is None: + print(f"Environment variable {env_name} is not set.", file=sys.stderr) + sys.exit(1) + return env_value + + +def sanitize_env(env: dict) -> dict: + """Remove sensitive environment variables from the environment.""" + global SENSITIVE_ENV_VARS + + new_env = env.copy() + for var in SENSITIVE_ENV_VARS: + new_env.pop(var, None) + + return new_env + + +def get_env_exports(env: dict) -> str: + """Return a string that can be used to export the environment variables.""" + return "\n".join( + f"export {k}={shlex.quote(v)}" for k, v in sanitize_env(env).items() + ) + + +def get_usable_cpu_id(): + """ + Get a usable CPU ID for binding tasks. 
+
+    Returns:
+        int: A CPU ID that can be used for binding tasks (0 if none found)
+    """
+    try:
+        cpu_affinity = os.sched_getaffinity(0)
+        if cpu_affinity:
+            # If bound to any cores, return the first one
+            cpu_id = list(cpu_affinity)[0]
+            logger.info(
+                f"Using CPU ID {cpu_id} (bound to cores {sorted(cpu_affinity)})"
+            )
+            return cpu_id
+    except Exception as e:
+        logger.warning(f"{CRS_ERR_LOG('cpu')} Could not determine CPU affinity: {e}")
+
+    logger.info("Defaulting to CPU 0")
+    return 0
+
+
+def get_with_model_provider(model_name: str) -> str:
+    """Return the model name prefixed with its inferred provider route."""
+    model_name = model_name.lower()
+    if model_name.startswith("vertex"):
+        return "vertex_ai/openai/" + model_name.split("/")[-1] + "-maas"
+    elif model_name.startswith("gpt") or model_name.startswith("o1") or model_name.startswith("o3-"):
+        return "openai/" + model_name
+    elif model_name.startswith("claude-"):
+        return "anthropic/" + model_name
+    elif model_name.startswith("gemini-"):
+        return "gemini/" + model_name
+    elif model_name.startswith("grok-"):
+        return "xai/" + model_name
+    elif model_name.startswith("zai-"):
+        return "vertex_ai/" + model_name + "-maas"
+    else:
+        return "unknown"
diff --git a/crs/fuzzers/atl-jazzer/MODULE.bazel b/crs/fuzzers/atl-jazzer/MODULE.bazel
index f7da34816..131d3337b 100644
--- a/crs/fuzzers/atl-jazzer/MODULE.bazel
+++ b/crs/fuzzers/atl-jazzer/MODULE.bazel
@@ -353,7 +353,7 @@ http_jar(
 http_jar(
     name = "atlantis_static-analysis",
-    sha256 = "be88c9ca872747a2703617a56dc260355e04be70f87d8d8c0262c0e49e825a9f",
+    sha256 = "36465873f6a335ef905d81ffa231e56356c532a9a806bdc1dba33c896950d951",
     urls = ["file:third_party/static-analysis/static-analysis-1.0.jar"],
 )
diff --git a/crs/fuzzers/atl-jazzer/third_party/static-analysis/static-analysis-1.0.jar b/crs/fuzzers/atl-jazzer/third_party/static-analysis/static-analysis-1.0.jar
index 101cf64ce..824e72176 100644
Binary files a/crs/fuzzers/atl-jazzer/third_party/static-analysis/static-analysis-1.0.jar and b/crs/fuzzers/atl-jazzer/third_party/static-analysis/static-analysis-1.0.jar differ
diff --git a/crs/javacrs_modules/__init__.py b/crs/javacrs_modules/__init__.py
index f8ca75a01..f2b31ad96 100755
--- a/crs/javacrs_modules/__init__.py
+++ b/crs/javacrs_modules/__init__.py
@@ -1,4 +1,4 @@
-from .codeql import CodeQL, CodeQLParams
+from .sinkdetection import SinkDetection, SinkDetectionParams
 from .concolic import ConcolicExecutor, ConcolicExecutorParams
 from .cpuallocator import CPUAllocator, CPUAllocatorParams
 from .crashmanager import CrashManager, CrashManagerParams
@@ -37,8 +37,8 @@
     "AtlDirectedJazzerParams",
     "AtlLibAFLJazzer",
     "AtlLibAFLJazzerParams",
-    "CodeQL",
-    "CodeQLParams",
+    "SinkDetection",
+    "SinkDetectionParams",
     "SeedMerger",
     "SeedMergerParams",
     "LLMPOCGenerator",
diff --git a/crs/javacrs_modules/base_objs.py b/crs/javacrs_modules/base_objs.py
index 721ba69e8..62a766f2c 100755
--- a/crs/javacrs_modules/base_objs.py
+++ b/crs/javacrs_modules/base_objs.py
@@ -274,10 +274,10 @@ def frm_dict(cls, coord_dict: Dict[str, Any]) -> "InsnCoordinate":
             class_name=coord_dict.get("class_name", None),
             method_name=coord_dict.get("method_name", None),
             method_desc=coord_dict.get("method_desc", None),
-            bytecode_offset=int(coord_dict.get("bytecode_offset") or -1),
+            bytecode_offset=int(v) if (v := coord_dict.get("bytecode_offset")) is not None else -1,
             mark_desc=coord_dict.get("mark_desc", None),
             file_name=coord_dict.get("file_name", None),
-            line_num=int(coord_dict.get("line_num") or -1),
+            line_num=int(v) if (v :=
coord_dict.get("line_num")) is not None else -1, ) def to_conf(self) -> str | None: diff --git a/crs/javacrs_modules/codeql.py b/crs/javacrs_modules/codeql.py deleted file mode 100755 index a2f0e7bd0..000000000 --- a/crs/javacrs_modules/codeql.py +++ /dev/null @@ -1,358 +0,0 @@ -#!/usr/bin/env python3 -import asyncio -import json -import os -import shlex -import shutil -import tarfile -import time -import traceback -from dataclasses import asdict -from pathlib import Path -from typing import Any, Dict, List - -import aiofiles -from libCRS import CRS, HarnessRunner, Module -from pydantic import BaseModel, Field, field_validator - -from .base_objs import Sinkpoint -from .utils import ( - CRS_ERR_LOG, - CRS_WARN_LOG, - download_file_async, - get_env_exports, - get_env_or_abort, - run_process_and_capture_output, -) -from .utils_nfs import ( - get_sarif_shared_codeql_db_done_file, - get_sarif_shared_codeql_db_path, -) - -CRS_ERR = CRS_ERR_LOG("codeql-mod") -CRS_WARN = CRS_WARN_LOG("codeql-mod") - - -class CodeQLParams(BaseModel): - enabled: bool = Field( - ..., description="**Mandatory**, true/false to enable or disable this module." - ) - - @field_validator("enabled") - def enabled_should_be_boolean(cls, v): - if not isinstance(v, bool): - raise ValueError("enabled must be a boolean") - return v - - -class CodeQL(Module): - """CodeQL analysis module for Java CRS.""" - - def __init__( - self, - name: str, - crs: CRS, - params: CodeQLParams, - run_per_harness: bool, - ): - super().__init__(name, crs, run_per_harness) - self.params = params - self.enabled = self.params.enabled - self.workdir = self.get_workdir("") / self.crs.cp.name - self.workdir.mkdir(parents=True, exist_ok=True) - self.db_dir = self.workdir / "codeql-db" - self.db_dir.mkdir(parents=True, exist_ok=True) - self.db = self.db_dir / "codeql" - self.results_file = self.workdir / "codeql_result.json" - self.results_file.parent.mkdir(parents=True, exist_ok=True) - self.tool_cwd = Path(get_env_or_abort("JAVA_CRS_SRC")) / "codeql" - - def _init(self): - pass - - async def _async_prepare(self): - pass - - async def _async_test(self, hrunner: HarnessRunner): - pass - - async def _async_get_mock_result(self, hrunner): - self.logH(hrunner, "Mock result for CodeQL") - - async def _extract_tarball(self, tarball: Path, outdir: Path): - try: - if outdir.exists(): - shutil.rmtree(outdir) - - outdir.mkdir(parents=True, exist_ok=True) - - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, lambda: tarfile.open(tarball, "r:gz").extractall(path=outdir) - ) - - self.logH(None, f"Extracted {tarball.name} to {outdir}") - return True - except Exception as e: - self.logH( - None, f"{CRS_ERR} Extraction error: {str(e)} {traceback.format_exc()}" - ) - return False - - async def _monitor_codeql_database(self) -> bool: - self.logH(None, "Monitoring for CodeQL database") - - interval = 60 - counter = 0 - - while self.crs.should_continue(): - try: - if counter % interval != 0: - counter += 1 - await asyncio.sleep(1) - continue - else: - counter += 1 - - # Check SARIF shared path - sarif_db = get_sarif_shared_codeql_db_path() - sarif_db_done = get_sarif_shared_codeql_db_done_file() - - db_files_exist = ( - sarif_db_done - and sarif_db_done.exists() - and sarif_db - and sarif_db.exists() - ) - if not db_files_exist: - continue - - self.logH(None, f"Found codeql.db at: {sarif_db}") - local_db, file_ok = await download_file_async( - src_path=sarif_db, - dst_dir=self.workdir, - logger=lambda msg: self.logH(None, msg), - err_tag=CRS_ERR, - ) - - 
if not file_ok: - self.logH( - None, - f"{CRS_WARN} Failed to download codeql.db {sarif_db}, retry", - ) - continue - if await self._extract_tarball(local_db, self.db_dir): - if not self.db.exists(): - self.logH( - None, - f"{CRS_ERR} Expected codeql db does not exist: {self.db}", - ) - continue - self.logH( - None, - f"codeql.db extracted successfully to {self.db}", - ) - return True - else: - self.logH( - None, - f"{CRS_ERR} Failed to extract codeql.db {local_db}, retry", - ) - except Exception as e: - self.logH( - None, f"{CRS_ERR} Monitor error: {str(e)} {traceback.format_exc()}" - ) - - self.logH(None, f"{CRS_WARN} Monitoring ended without finding database") - return False - - async def _create_command_sh( - self, - cmd: List[str], - script_name: str, - working_dir: str, - timeout: int, - cpu_list: List[int], - buffer_output: bool, - log_prefix: str, - ) -> Path: - if timeout: - cmd = ["timeout", "-s", "SIGKILL", f"{timeout}s"] + cmd - if cpu_list: - cpu_str = ",".join(map(str, cpu_list)) - cmd = ["taskset", "-c", cpu_str] + cmd - if buffer_output: - cmd = ["stdbuf", "-e", "0", "-o", "0"] + cmd - - cmd_str = " ".join(shlex.quote(str(arg)) for arg in cmd) - command_sh_content = f"""#!/bin/bash -# Env -{get_env_exports(os.environ)} -# Cmd -{f'cd "{working_dir}"' if working_dir else ''} -{cmd_str} > {str(self.workdir.resolve())}/{log_prefix}.log 2>&1 -""" - command_sh = self.workdir / script_name - async with aiofiles.open(command_sh, "w") as f_sh: - await f_sh.write(command_sh_content) - command_sh.chmod(0o755) - return command_sh - - async def _run_codeql_script( - self, cpu_list: List[int], script_name: str, log_prefix: str - ): - """Shared logic for running CodeQL scripts.""" - self.logH(None, f"Running CodeQL {script_name} with CPUs: {cpu_list}") - - try: - rest_time = self.crs.rest_time() - if rest_time <= 0: - self.logH( - None, - f"{CRS_WARN} No time left to run {script_name} (rest_time={rest_time})", - ) - return - - script_path = self.tool_cwd / script_name - if not script_path.exists(): - raise FileNotFoundError(f"CodeQL script not found: {script_path}") - - command_sh = await self._create_command_sh( - [ - str(script_path), - str(self.db), - str(self.results_file), - ], - f"command_{script_name}", - str(self.tool_cwd), - timeout=rest_time, - cpu_list=cpu_list, - buffer_output=True, - log_prefix=log_prefix, - ) - - self.logH(None, f"Running CodeQL command: {command_sh}") - log_file = self.workdir / "run.log" - - ret = await run_process_and_capture_output(command_sh, log_file) - if ret == 0: - self.logH( - None, - f"CodeQL {script_name} finished (ret: {ret}) in {time.time() - self.crs.start_time:.2f}s", - ) - else: - # Exit with != 0 and not killed by SIGKILL is unexpected - self.logH( - None, - f"{CRS_ERR} CodeQL {script_name} unexpectedly exited with ret {ret} in {time.time() - self.crs.start_time:.2f}s", - ) - return ret - - except Exception as e: - self.logH( - None, - f"{CRS_ERR} CodeQL {script_name} error: {str(e)} {traceback.format_exc()}", - ) - return -1 - - async def _codeql_query(self, cpu_list: List[int]): - """Run CodeQL queries.""" - return await self._run_codeql_script(cpu_list, "run.sh", "query") - - async def _update_sink(self, sink_dict: dict) -> bool: - # Update sinkmanager - try: - sink_dict = sink_dict["coord"] - code_coord = self.crs.query_code_coord( - sink_dict["class_name"], sink_dict["line_num"] - ) - if code_coord is None: - self.logH( - None, - f"{CRS_WARN} Filter out sinkpoint {sink_dict['class_name']}:{sink_dict['line_num']} which has no code 
coordinate", - ) - return False - - self.logH( - None, - f"Sinkpoint {sink_dict['class_name']}:{sink_dict['line_num']} found code coordinate: {code_coord}", - ) - sink_dict.update(asdict(code_coord)) - sink = Sinkpoint.frm_dict(sink_dict) - self.logH(None, f"CodeQL update sinkpoint to sinkmanager: {sink}") - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) - return True - except Exception as e: - self.logH( - None, - f"{CRS_ERR} updating tgt sink {sink_dict} to sinkmanager: {str(e)} {traceback.format_exc()}", - ) - return False - - async def _codeql_result_parsing(self): - self.logH(None, "Parsing CodeQL results") - try: - if not self.results_file.exists(): - self.logH( - None, f"{CRS_ERR} Results file not found: {self.results_file}" - ) - return - - async with aiofiles.open(self.results_file, "r") as f: - results = await f.read() - - # Read json results - sinkpoints = json.loads(results) - if not isinstance(sinkpoints, list): - self.logH(None, f"{CRS_ERR} Invalid results format: expected list") - return - - total_sinks = len(sinkpoints) - self.logH(None, f"Found {total_sinks} sinkpoints in results") - - kept_sinks = 0 - for sink_dict in sinkpoints: - if await self._update_sink(sink_dict): - kept_sinks += 1 - - filtered_sinks = total_sinks - kept_sinks - self.logH( - None, f"Sinkpoints stats: {filtered_sinks} filtered, {kept_sinks} kept" - ) - - except Exception as e: - self.logH( - None, - f"{CRS_ERR} Result parsing error: {str(e)} {traceback.format_exc()}", - ) - - async def _run_codeql_analysis(self, cpu_list: List[int]): - try: - await self._codeql_query(cpu_list) - await self._codeql_result_parsing() - except Exception as e: - self.logH( - None, f"{CRS_ERR} Analysis error: {str(e)} {traceback.format_exc()}" - ) - - async def _async_run(self, _) -> Dict[str, Any]: - if not self.enabled: - self.logH(None, f"Module {self.name} is disabled") - return - - self.logH(None, f"Starting {self.name}") - - try: - cpu_list = await self.crs.cpuallocator.poll_allocation(None, self.name) - self.logH(None, f"Allocated CPUs: {cpu_list}") - - if await self._monitor_codeql_database(): - await self._run_codeql_analysis(cpu_list) - else: - self.logH(None, f"{CRS_WARN} Skipping analysis (no database found)") - except Exception as e: - self.logH( - None, f"{CRS_ERR} Module failed: {str(e)} {traceback.format_exc()}" - ) - - self.logH(None, f"{self.name} ended") diff --git a/crs/javacrs_modules/crashmanager.py b/crs/javacrs_modules/crashmanager.py index 78bc7be10..8a4cc1a34 100755 --- a/crs/javacrs_modules/crashmanager.py +++ b/crs/javacrs_modules/crashmanager.py @@ -55,6 +55,7 @@ def __init__( self.max_payload_per_ty = 2 self.max_frame_layer = 10 self._submitted_result_jsons = asyncio.Queue() + self._unmatched_crashes = [] def _init(self): pass @@ -277,7 +278,11 @@ async def _process_crash(self, hrunner, result_json_path: Path, crash: list): self.logH( None, f"CrashManager update sinkpoint to sinkmanager from crash: {sink}" ) - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="crashmanager") + elif not exp_id.startswith("NONSEC-") and sink_coord is None: + self._unmatched_crashes.append( + (hrunner, sanitizer, crash_msg, frames, dedup_token, artifact_name, artifact_abspath) + ) async def _process_result_json(self, hrunner, result_json_path: Path): async for crash in self._get_unhandled_crashes(hrunner, result_json_path): @@ -340,6 +345,36 @@ async def _check_and_process_result_json(self, path_tuple, last_mtimes): None, 
f"{CRS_ERR} checking file '{path}': {e} {traceback.format_exc()}" ) + async def _retry_unmatched_crashes(self): + """Retry matching crashes that had no sinkpoint match on first processing.""" + if not self._unmatched_crashes: + return + still_unmatched = [] + for hrunner, sanitizer, crash_msg, frames, dedup_token, artifact_name, artifact_abspath in self._unmatched_crashes: + sink_coord = await self.crs.sinkmanager.match_sinkpoint(frames) + if sink_coord is not None: + sink = Sinkpoint.frm_crash( + Crash( + hrunner.harness.name, + sink_coord, + sanitizer, + crash_msg, + frames, + dedup_token, + artifact_name, + artifact_abspath, + ) + ) + self.logH( + None, f"CrashManager update sinkpoint to sinkmanager from previously unmatched crash: {sink}" + ) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="crashmanager") + else: + still_unmatched.append( + (hrunner, sanitizer, crash_msg, frames, dedup_token, artifact_name, artifact_abspath) + ) + self._unmatched_crashes = still_unmatched + async def _monitor_all_result_jsons(self): """Monitor all result.json files with a simple sequential loop.""" path_tuples = await self._collect_all_result_json_paths() @@ -372,6 +407,7 @@ async def _monitor_all_result_jsons(self): f"{CRS_ERR} processing submitted result.json: {e} {traceback.format_exc()}", ) + await self._retry_unmatched_crashes() await asyncio.sleep(1) except asyncio.CancelledError: diff --git a/crs/javacrs_modules/expkit.py b/crs/javacrs_modules/expkit.py index f48599bb5..a7a3fc9be 100755 --- a/crs/javacrs_modules/expkit.py +++ b/crs/javacrs_modules/expkit.py @@ -39,13 +39,9 @@ class ExpKitParams(BaseModel): 300, description="**Optional**, timeout in seconds for each beepseed exploitation. Default is 300 seconds.", ) - gen_models: str = Field( + gen_model: str = Field( ..., - description="**Mandatory**, comma-separated list of generation models. Format: 'model1:weight1,model2:weight2,...'. Example: 'o1-preview:10,claude-3-7-sonnet-20250219:20,none:5'", - ) - x_models: str = Field( - ..., - description="**Mandatory**, comma-separated list of extraction models. Format: 'model1:weight1,model2:weight2,...'. Example: 'gpt-4o:10,o3-mini:20,none:5'", + description="**Mandatory**, LLM model used by the exploitation agent.", ) debug_list_txt: Optional[str] = Field( None, @@ -68,21 +64,10 @@ def exp_time_should_be_positive(cls, v): raise ValueError("exp_time must be a positive integer") return v - @field_validator("gen_models", "x_models") - def models_should_be_valid(cls, v): - if not isinstance(v, str): - raise ValueError("gen_models and x_models must be strings") - models = v.split(",") - for model in models: - if ":" not in model: - raise ValueError( - "Invalid format for gen_models or x_models. 
Expected 'model:weight'" - ) - model_name, weight = model.split(":") - if not weight.isdigit() or int(weight) <= 0: - raise ValueError( - "Weight must be a positive integer in the format 'model:weight'" - ) + @field_validator("gen_model") + def gen_model_should_be_non_empty(cls, v): + if not isinstance(v, str) or not v.strip(): + raise ValueError("gen_model must be a non-empty string") return v @field_validator("debug_list_txt") @@ -272,8 +257,7 @@ def __init__( self.enabled = self.params.enabled self.monitor_interval = 0.1 self.exp_time = self.params.exp_time - self.gen_models = self.params.gen_models - self.x_models = self.params.x_models + self.gen_model = self.params.gen_model self.handled_beeps: Set[str] = set() # NOTE: This should be init in runtime, in async_run self.target_harnesses: List[str] | None = None @@ -397,7 +381,7 @@ async def process_beepseed(json_path: Path): None, f"Expkit update sinkpoint to sinkmanager from beepseed: {sink}", ) - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="expkit") # Mark after it is successfully handled as race can happen when loading await self._mark_handled_path(json_path) @@ -489,7 +473,7 @@ async def _process_debug_beepseed(self, json_path: Path): None, f"Expkit update sinkpoint to sinkmanager from debug beepseed: {sink}", ) - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="expkit") except (FileNotFoundError, JSONDecodeError) as e: self.logH( @@ -511,9 +495,14 @@ async def _gen_exploit_script( result_file: Path, log_file: Path, exp_time: int, + time_left: int, ) -> Path: """Generate shell script to execute expkit tool for a beepseed.""" command = [ + "timeout", + "-s", + "SIGKILL", + f"{str(time_left)}s", "taskset", "-c", str(cpu_id), @@ -526,10 +515,8 @@ async def _gen_exploit_script( str(self.crs.meta.meta_path.resolve()), "--exp-time", str(exp_time), - "--gen-models", - self.gen_models, - "--x-models", - self.x_models, + "--gen-model", + self.gen_model, "--workdir", str(output_dir.resolve()), "--verbose", @@ -694,6 +681,7 @@ async def _exploit_one_beepseed( result_file, log_file, actual_exp_time, + int(self.ttl_fuzz_time - elapsed_time), ) self.logH(None, f"Starting exploitation with timeout {actual_exp_time}s") diff --git a/crs/javacrs_modules/jazzer.py b/crs/javacrs_modules/jazzer.py index 6e093a4fa..5d1826f0c 100755 --- a/crs/javacrs_modules/jazzer.py +++ b/crs/javacrs_modules/jazzer.py @@ -64,6 +64,12 @@ class JazzerParams(BaseModel): 1048576, description="**Optional**, libfuzzer -max_len param. If unset, will be 1048576 (1M).", ) + fuzz_time: int = Field( + 0, description="**Optional**, max fuzzing time. Default: 0 (unlimited)." + ) + wait_for_llmpocgen: bool = Field( + False, description="**Optional**, whether to wait for llmpocgen to finish before starting to fuzz. Default: False." 
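+        # Set by LLMPOCGenerator._notify_jazzer() -> notify_llm_poc_gen_done() once
+        # blackboard generation finishes (see the llmpocgen changes later in this diff).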
+ ) @field_validator("enabled") def enabled_should_be_boolean(cls, v): @@ -103,6 +109,7 @@ def __init__( self.ttl_fuzz_time: int = self.crs.ttl_fuzz_time self.envs["FUZZ_TTL_FUZZ_TIME"] = str(self.ttl_fuzz_time) self.params = params + self.llm_poc_gen_done_event = asyncio.Event() self._init_from_params() @abstractmethod @@ -112,6 +119,13 @@ def _init_from_params(self): self.len_control = self.params.len_control self.max_len = self.params.max_len self.envs["FUZZ_KEEP_SEED"] = "on" if self.params.keep_seed else "off" + # If jazzer is configured with fuzz_time > 0, the actual fuzz_time is the min of + # the module's fuzz_time and the CRS-wide ttl_fuzz_time + self.fuzz_time = self.params.fuzz_time + if self.fuzz_time > 0: + self.ttl_fuzz_time = min(self.fuzz_time, self.ttl_fuzz_time) + self.envs["FUZZ_TTL_FUZZ_TIME"] = str(self.ttl_fuzz_time) + self.wait_for_llmpocgen = self.params.wait_for_llmpocgen def _init(self): pass @@ -160,6 +174,10 @@ async def add_corpus_file(self, hrunner: HarnessRunner, seed_file: Path): f"Added seed file {seed_file.resolve()} into corpus_dir {corpus_dir.resolve()} as {dst_file.resolve()}", ) + async def notify_llm_poc_gen_done(self): + if self.enabled and self.wait_for_llmpocgen: + self.llm_poc_gen_done_event.set() + async def get_expected_fuzz_instance_dirs( self, hrunner: HarnessRunner ) -> List[Path]: @@ -378,6 +396,14 @@ async def _async_run_instance( hrunner, fuzz_id, workdir ) + # Wait for llmpocgen to finish before starting to fuzz (if configured) + if self.wait_for_llmpocgen: + self.logH( + hrunner, + f"{fuzz_id} is waiting for llmpocgen to finish before starting to fuzz", + ) + await self.llm_poc_gen_done_event.wait() + env = await self._prepare_environment( hrunner, fuzz_id, @@ -596,7 +622,8 @@ async def _fuzzer_specific_env_setup( if self.crs.sinkmanager.enabled: sink_conf_file = self.crs.meta.get_custom_sink_conf_path() env["FUZZ_CUSTOM_SINK_CONF"] = str(sink_conf_file.resolve()) - # Set FUZZ_SSMODE environment variable + # Disable Jazzer's built-in hardcoded sinkpoint table when ssmode is on, + # so only sinks fed via FUZZ_CUSTOM_SINK_CONF get instrumented. env["FUZZ_SSMODE"] = "on" if self.crs.is_ssmode() else "off" # Set JACOCO_COV_DUMP_PERIOD environment variable if specified if self.params.jacoco_cov_dump_period is not None: diff --git a/crs/javacrs_modules/llmpocgen.py b/crs/javacrs_modules/llmpocgen.py index 05d842f90..67757eaad 100755 --- a/crs/javacrs_modules/llmpocgen.py +++ b/crs/javacrs_modules/llmpocgen.py @@ -39,6 +39,20 @@ class LLMPOCGeneratorParams(BaseModel): ..., description="**Mandatory**, mode of `llmpocgen` module, one of 'crs' or 'static' or 'onetime', static mode is for testing purpose.", ) + models: List[str] = Field( + default_factory=lambda: [ + "claude-opus-4-20250514", + "o3", + "claude-sonnet-4-20250514", + "gemini-2.5-pro", + "gpt-4.1", + ], + description="**Optional**, list of available LLM models. Tasks use their own preference order but are constrained to this set. Default: all supported models.", + ) + scan_sinks: bool = Field( + True, + description="**Optional**, whether llmpocgen should run its own Joern-based sink discovery (FROM_INSIDE). Default: True. 
Set to False to rely solely on sinks from the CRS sinkmanager.", + ) diff_max_len: int = Field( 65536, description="Maximum length for diff content processing, must be between 16K and 512K.", @@ -156,9 +170,6 @@ async def _gen_command_script( self.mode, ] - if self.crs.is_ssmode(): - command.append("--dev") - command.extend( [ "--cg", @@ -169,7 +180,10 @@ async def _gen_command_script( self.diff_max_len, "--worker", self.worker_num, + "--models", + ",".join(self.params.models), ] + + (["--scan-sinks"] if self.params.scan_sinks else []) ) command_str = " ".join(shlex.quote(str(arg)) for arg in command) # N.B. stdout & stderr are redirected to avoid python pipe OOM issues @@ -207,6 +221,9 @@ async def _async_gen_blackboard( # It is abnormal as long as we are not killed by SIGKILL self.logH(None, f"{CRS_ERR} llm-poc-gen unexpectedly exits with ret {ret}") + # Notify jazzer about end of llm-poc-gen + await self._notify_jazzer() + async def _async_parse_blackboard( self, blackboard_path: Path, @@ -385,7 +402,7 @@ async def _process_sinkpoints(self, sinkpoints, processed_sinkpoints): # Notify sink manager sink = Sinkpoint.frm_dict(sink_dict) self.logH(None, f"llmpocgen update sinkpoint to sinkmanager: {sink}") - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="llmpocgen") except Exception as e: self.logH( @@ -461,6 +478,13 @@ async def _async_monitor_blackboard(self, workdir: Path, blackboard_path: Path): await asyncio.sleep(1) + async def _notify_jazzer(self): + """Notifies all enabled Jazzer modules that LLM POC generation is done.""" + for mod in self.crs.modules: + if is_fuzzing_module(mod) and mod.enabled: + self.logH(None, f"Notifying Jazzer module {mod.name} of llm-poc-gen completion.") + await mod.notify_llm_poc_gen_done() + async def _async_run(self, _): """Runs the LLMPOCGenerator module for a given CP.""" if not self.enabled: @@ -479,15 +503,21 @@ async def _async_run(self, _): blackboard_path = self.workdir / "blackboard" - results = await asyncio.gather( - self._async_gen_blackboard(self.workdir, cp_name, cpu_list), - self._async_monitor_blackboard(self.workdir, blackboard_path), - return_exceptions=True, + # Create monitor task + monitor_task = asyncio.create_task( + self._async_monitor_blackboard(self.workdir, blackboard_path) ) - for result in results: - if isinstance(result, Exception): - raise result + # Wait for gen_blackboard to finish + try: + await self._async_gen_blackboard(self.workdir, cp_name, cpu_list) + finally: + # Cancel monitor when gen_blackboard finishes + monitor_task.cancel() + try: + await monitor_task + except asyncio.CancelledError: + self.logH(None, "Blackboard monitor cancelled after gen_blackboard finished") except Exception as e: self.logH( diff --git a/crs/javacrs_modules/scripts/run-jazzer.sh b/crs/javacrs_modules/scripts/run-jazzer.sh index 306b8a85b..bd77210a1 100755 --- a/crs/javacrs_modules/scripts/run-jazzer.sh +++ b/crs/javacrs_modules/scripts/run-jazzer.sh @@ -319,6 +319,27 @@ fi # Kick off the fuzzer # cat > "${WORK_DIR}/_run_fuzzer_timeout_stub.sh" < M growth restarts are gated by COOLDOWN_SECONDS so each Jazzer instance +# gets at least that much productive run time before it can be killed. The +# 0 -> N case (first sinks ever arriving) bypasses the cooldown. +POLL_INTERVAL=5 +GROWTH_PCT=10 +COOLDOWN_SECONDS=300 + +# Recursively signal a process and all its descendants. 
We walk the tree +# ourselves instead of using process groups because run_fuzzer stays in the +# stub's process group, so the outer 'timeout -s SIGKILL' at campaign end +# continues to clean up the entire subtree in one shot. +kill_tree() { + local pid=\$1 sig=\$2 + local c + for c in \$(pgrep -P "\$pid" 2>/dev/null); do + kill_tree "\$c" "\$sig" + done + kill -"\$sig" "\$pid" 2>/dev/null || true +} + while true do @@ -336,12 +357,64 @@ do fi fi - timeout -s SIGKILL 900s \ - stdbuf -e 0 -o 0 \ - run_fuzzer ${FUZZ_TARGET_HARNESS} \ - --agent_path=\${JAZZER_DIR}/jazzer_standalone_deploy.jar \ - \${ATL_OPTIONS} \ - "\$@" || echo @@@@@ exit code of Jazzer is $? @@@@@ >&2 + # Snapshot the sink count at iteration start. One line = one sink + # (caller#... or api#...), and sinks are monotonic, so line count is + # a direct measure of "did the sink pool grow". + baseline_lines=\$(wc -l < "\$ATLJAZZER_CUSTOM_SINKPOINT_CONF" 2>/dev/null || echo 0) + iter_start_ts=\$(date +%s) + echo "SINK_RESTART: starting run_fuzzer (baseline_lines=\$baseline_lines, growth_pct=\${GROWTH_PCT}, cooldown=\${COOLDOWN_SECONDS}s)" + + stdbuf -e 0 -o 0 \ + run_fuzzer ${FUZZ_TARGET_HARNESS} \ + --agent_path=\${JAZZER_DIR}/jazzer_standalone_deploy.jar \ + \${ATL_OPTIONS} \ + "\$@" & + RF_PID=\$! + + # Watch for sink-conf growth while run_fuzzer runs. + should_restart=0 + while kill -0 \$RF_PID 2>/dev/null; do + sleep \$POLL_INTERVAL + now=\$(date +%s) + cur_lines=\$(wc -l < "\$ATLJAZZER_CUSTOM_SINKPOINT_CONF" 2>/dev/null || echo 0) + + # 0 -> N always triggers (first sinks ever arriving). No cooldown so + # Jazzer stops wasting time fuzzing with an empty sink set. + if [[ \$baseline_lines -eq 0 && \$cur_lines -gt 0 ]]; then + echo "SINK_RESTART: 0 -> \$cur_lines, requesting restart" + should_restart=1 + break + fi + + # N -> M triggers when growth >= GROWTH_PCT, but only after COOLDOWN_SECONDS + # have elapsed since this iteration started — gives each Jazzer instance a + # minimum productive run time and prevents restart thrashing on small pools + # or trickling sink sources. + if [[ \$baseline_lines -gt 0 && \$((cur_lines * 100)) -ge \$((baseline_lines * (100 + GROWTH_PCT))) ]]; then + if [[ \$((now - iter_start_ts)) -ge \$COOLDOWN_SECONDS ]]; then + echo "SINK_RESTART: \$baseline_lines -> \$cur_lines (>= \${GROWTH_PCT}%), requesting restart" + should_restart=1 + break + fi + fi + done + + if [[ \$should_restart -eq 1 ]]; then + # Walk run_fuzzer's descendants (/out/, atl-jazzer, java) and + # signal each one. SIGTERM first so libfuzzer can flush corpus / jacoco / + # result.json; SIGKILL after 15s grace. + kill_tree \$RF_PID TERM + for _ in {1..15}; do + kill -0 \$RF_PID 2>/dev/null || break + sleep 1 + done + kill_tree \$RF_PID KILL + fi + wait \$RF_PID 2>/dev/null + RF_RET=\$? + if [[ \$RF_RET -ne 0 && \$should_restart -eq 0 ]]; then + echo "@@@@@ exit code of Jazzer is \$RF_RET @@@@@" >&2 + fi # Clean up! 
rm -rf ${DIRECTED_CLASS_DUMP_DIR} diff --git a/crs/javacrs_modules/sinkdetection.py b/crs/javacrs_modules/sinkdetection.py new file mode 100644 index 000000000..a41253384 --- /dev/null +++ b/crs/javacrs_modules/sinkdetection.py @@ -0,0 +1,489 @@ +#!/usr/bin/env python3 +import asyncio +import json +import os +import shlex +import time +import traceback +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, List + +import aiofiles +from libCRS import CRS, HarnessRunner, Module +from pydantic import BaseModel, Field, field_validator + +from .base_objs import Sinkpoint +from .utils import ( + CRS_ERR_LOG, + CRS_WARN_LOG, + get_env_exports, + get_env_or_abort, + run_process_and_capture_output, +) + +CRS_ERR = CRS_ERR_LOG("sinkdetection-mod") +CRS_WARN = CRS_WARN_LOG("sinkdetection-mod") + + +CWE_TO_MARK_DESC = { + "CWE-022": "sink-FilePathTraversal", + "CWE-078": "sink-OsCommandInjection", + "CWE-089": "sink-SqlInjection", + "CWE-090": "sink-LdapInjection", + "CWE-094": "sink-ScriptEngineInjection", + "CWE-117": "sink-RemoteJNDILookup", + "CWE-470": "sink-OsCommandInjection", + "CWE-502": "sink-RemoteCodeExecution", + "CWE-611": "sink-ServerSideRequestForgery", + "CWE-643": "sink-XPathInjection", + "CWE-730": "sink-RegexInjection", + "CWE-918": "sink-ServerSideRequestForgery", +} + +ALL_CWES = sorted(CWE_TO_MARK_DESC.keys()) + + +class SinkDetectionParams(BaseModel): + enabled: bool = Field( + ..., description="**Mandatory**, true/false to enable or disable this module." + ) + cwes: List[str] = Field( + default_factory=lambda: list(ALL_CWES), + description="**Optional**, subset of supported CWEs to detect. Any CWE not in the 12 supported CWEs is rejected.", + ) + gen_model: str = Field( + "gpt-5", + description="**Optional**, LLM model used by the filtering agent.", + ) + max_iterations: int = Field( + 15, + description="**Optional**, max iterations for filtering agent.", + ) + max_workers: int = Field( + 10, + description="**Optional**, max parallel workers for filtering agent.", + ) + + @field_validator("enabled") + def enabled_should_be_boolean(cls, v): + if not isinstance(v, bool): + raise ValueError("enabled must be a boolean") + return v + + @field_validator("cwes") + def cwes_must_be_supported(cls, v): + invalid = [c for c in v if c not in ALL_CWES] + if invalid: + raise ValueError( + f"Unsupported CWEs {invalid}. Supported: {ALL_CWES}" + ) + return v + + +class SinkDetection(Module): + def __init__( + self, + name: str, + crs: CRS, + params: SinkDetectionParams, + run_per_harness: bool, + ): + super().__init__(name, crs, run_per_harness) + self.params = params + self.enabled = self.params.enabled + # Resolve the CWE list once: empty means "all 12 supported CWEs". 
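+        # e.g. SinkDetectionParams(enabled=True, cwes=["CWE-022", "CWE-089"]) would
+        # narrow detection to path traversal and SQL injection (hypothetical config).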
+ self.cwes: List[str] = list(self.params.cwes) if self.params.cwes else list(ALL_CWES) + self.workdir = self.get_workdir("") / self.crs.cp.name + self.workdir.mkdir(parents=True, exist_ok=True) + self.db = Path("/out/crs/codeql-db") + self.results_file = self.workdir / "codeql_result.json" + self.reachable_file = self.workdir / "reachable_sinks.json" + self.results_file.parent.mkdir(parents=True, exist_ok=True) + self.java_crs_src = Path(get_env_or_abort("JAVA_CRS_SRC")) + self.tool_cwd = self.java_crs_src / "codeql" + self.filtering_agent_dir = self.java_crs_src / "filtering-agent" + + def _init(self): + pass + + async def _async_prepare(self): + pass + + async def _async_test(self, hrunner: HarnessRunner): + pass + + async def _async_get_mock_result(self, hrunner): + pass + + async def _create_command_sh( + self, + cmd: List[str], + script_name: str, + working_dir: str, + timeout: int, + cpu_list: List[int], + buffer_output: bool, + log_prefix: str, + ) -> Path: + if timeout: + cmd = ["timeout", "-s", "SIGKILL", f"{timeout}s"] + cmd + if cpu_list: + cpu_str = ",".join(map(str, cpu_list)) + cmd = ["taskset", "-c", cpu_str] + cmd + if buffer_output: + cmd = ["stdbuf", "-e", "0", "-o", "0"] + cmd + + cmd_str = " ".join(shlex.quote(str(arg)) for arg in cmd) + command_sh_content = f"""#!/bin/bash +# Env +{get_env_exports(os.environ)} +# Cmd +{f'cd "{working_dir}"' if working_dir else ''} +{cmd_str} > {str(self.workdir.resolve())}/{log_prefix}.log 2>&1 +""" + command_sh = self.workdir / script_name + async with aiofiles.open(command_sh, "w") as f_sh: + await f_sh.write(command_sh_content) + command_sh.chmod(0o755) + return command_sh + + async def _run_codeql_query(self, cpu_list: List[int]): + """Run CodeQL queries via run.sh.""" + self.logH(None, f"Running CodeQL queries with CPUs: {cpu_list}") + + # Always pass the resolved CWE list (guaranteed non-empty subset of the 12). + os.environ["CODEQL_CWES"] = ",".join(self.cwes) + self.logH(None, f"CodeQL CWEs: {self.cwes}") + + try: + rest_time = self.crs.rest_time() + if rest_time <= 0: + self.logH(None, f"{CRS_WARN} No time left to run CodeQL queries") + return + + script_path = self.tool_cwd / "run.sh" + if not script_path.exists(): + raise FileNotFoundError(f"CodeQL script not found: {script_path}") + + cmd_args = [ + str(script_path), + str(self.db), + str(self.results_file), + ] + + command_sh = await self._create_command_sh( + cmd_args, + "command_run.sh", + str(self.tool_cwd), + timeout=rest_time, + cpu_list=cpu_list, + buffer_output=True, + log_prefix="query", + ) + + self.logH(None, f"Running CodeQL command: {command_sh}") + log_file = self.workdir / "run.log" + + ret = await run_process_and_capture_output(command_sh, log_file) + if ret == 0: + self.logH( + None, + f"CodeQL queries finished (ret: {ret}) in {time.time() - self.crs.start_time:.2f}s", + ) + else: + self.logH( + None, + f"{CRS_ERR} CodeQL queries exited with ret {ret} in {time.time() - self.crs.start_time:.2f}s", + ) + + except Exception as e: + self.logH( + None, + f"{CRS_ERR} CodeQL query error: {str(e)} {traceback.format_exc()}", + ) + + async def _run_reachability_filter(self, call_graph_path: Path) -> Path: + """Annotate CodeQL results with reachability and drop filtered_out sinks. + + Writes to self.reachable_file without modifying the raw CodeQL results. + Returns the path of the file downstream stages should consume + (reachable_file on success, raw results_file on failure). 
+ """ + script_path = self.tool_cwd / "find_reachable_sinks.py" + if not script_path.exists(): + self.logH(None, f"{CRS_WARN} find_reachable_sinks.py not found, skipping reachability filter") + return self.results_file + + self.logH(None, f"Running reachability filter with call graph: {call_graph_path}") + cmd = [ + "python3", str(script_path), + str(self.results_file), + str(call_graph_path), + "--output", str(self.reachable_file), + ] + command_str = " ".join(shlex.quote(str(arg)) for arg in cmd) + command_sh_content = f"""#!/bin/bash +# Env +{get_env_exports(os.environ)} +# Cmd +cd "{self.tool_cwd}" +{command_str} > "{self.workdir / 'reachability.log'}" 2>&1 +""" + command_sh = self.workdir / "reachability.sh" + async with aiofiles.open(command_sh, "w") as f: + await f.write(command_sh_content) + command_sh.chmod(0o755) + + ret = await run_process_and_capture_output( + command_sh, self.workdir / "reachability_run.log" + ) + if ret == 0 and self.reachable_file.exists(): + self.logH(None, f"Reachability filter completed, output: {self.reachable_file}") + return self.reachable_file + + self.logH(None, f"{CRS_WARN} Reachability filter exited with ret {ret}, using raw CodeQL results") + return self.results_file + + def _build_harness_cwe_pairs(self) -> str: + """Build harness:CWE pairs string from configured harnesses and CWEs.""" + harnesses = self.crs.get_target_harnesses() + pairs = [f"{h}:{c}" for h in harnesses for c in self.cwes] + return ",".join(pairs) + + def _get_empty_call_graph_path(self) -> Path: + """Create and return path to an empty call graph placeholder.""" + empty_cg = self.workdir / "empty-cg.json" + if not empty_cg.exists(): + empty_cg.write_text('{"nodes": [], "links": []}') + return empty_cg + + async def _wait_for_call_graph(self) -> Path: + """Wait for a call graph to become available. + + If llmpocgen is enabled, waits for its Joern CG. Otherwise checks + static analysis and SARIF listener CGs. Falls back to empty CG + if nothing becomes available before CRS time runs out. + """ + # Determine which CG to wait for based on enabled modules + primary_wait = None + if hasattr(self.crs, 'llmpocgen') and self.crs.llmpocgen.enabled: + primary_wait = self.crs.llmpocgen.joern_cg_file + self.logH(None, "Waiting for llmpocgen Joern call graph...") + + # All candidates to check while waiting + candidates = [] + if hasattr(self.crs, 'staticanalysis') and hasattr(self.crs.staticanalysis, 'soot_cg_file'): + candidates.append(self.crs.staticanalysis.soot_cg_file) + if hasattr(self.crs, 'sariflistener') and hasattr(self.crs.sariflistener, 'full_cg_file'): + candidates.append(self.crs.sariflistener.full_cg_file) + if hasattr(self.crs, 'llmpocgen') and hasattr(self.crs.llmpocgen, 'joern_cg_file'): + candidates.append(self.crs.llmpocgen.joern_cg_file) + + while self.crs.should_continue(): + # If we have a primary wait target, check it first + if primary_wait and primary_wait.exists(): + self.logH(None, f"Using call graph: {primary_wait}") + return primary_wait + + # Check all candidates + for cg in candidates: + if cg.exists(): + self.logH(None, f"Using call graph: {cg}") + return cg + + await asyncio.sleep(5) + + # Time ran out — use empty placeholder + self.logH(None, f"{CRS_WARN} No call graph available before timeout, using empty placeholder") + return self._get_empty_call_graph_path() + + async def _run_filtering_agent( + self, results_json: Path, cpu_list: List[int], call_graph_path: Path + ) -> Path: + """Run pick_sinks.py to assess exploitability. 
Returns the path of the results file downstream stages should use."""
+        assessed_output = self.workdir / "assessed_sinks.json"
+        pick_sinks_script = self.filtering_agent_dir / "sinkpicker" / "pick_sinks.py"
+
+        if not pick_sinks_script.exists():
+            self.logH(None, f"{CRS_WARN} pick_sinks.py not found at {pick_sinks_script}, using raw results")
+            return results_json
+
+        harness_cwe_pairs = self._build_harness_cwe_pairs()
+        self.logH(None, f"Filtering agent harness-CWE pairs: {harness_cwe_pairs}")
+        self.logH(None, f"Filtering agent call graph: {call_graph_path}")
+
+        command = [
+            "timeout", "-s", "SIGKILL", f"{self.crs.rest_time()}s",
+            "taskset", "-c", ",".join(map(str, cpu_list)),
+            "python3.12",
+            "-m", "sinkpicker.pick_sinks",
+            str(results_json),
+            str(assessed_output),
+            "--metadata", str(self.crs.meta.meta_path.resolve()),
+            "--harness-cwe-pairs", harness_cwe_pairs,
+            "--call-graph", str(call_graph_path),
+            "--gen-model", self.params.gen_model,
+            "--max-workers", str(self.params.max_workers),
+            "--max-iterations", str(self.params.max_iterations),
+        ]
+
+        command_str = " ".join(shlex.quote(str(arg)) for arg in command)
+        command_sh_content = f"""#!/bin/bash
+# Env
+{get_env_exports(os.environ)}
+# Cmd
+cd "{self.filtering_agent_dir}"
+{command_str} > "{self.workdir / 'pick_sinks.log'}" 2>&1
+"""
+        command_sh = self.workdir / "pick_sinks.sh"
+        async with aiofiles.open(command_sh, "w") as f:
+            await f.write(command_sh_content)
+        command_sh.chmod(0o755)
+
+        self.logH(None, "Running filtering agent...")
+        ret = await run_process_and_capture_output(
+            command_sh, self.workdir / "pick_sinks_run.log"
+        )
+
+        if ret == 0 and assessed_output.exists():
+            self.logH(None, "Filtering agent completed successfully")
+            return assessed_output
+        else:
+            self.logH(
+                None,
+                f"{CRS_WARN} Filtering agent exited with ret {ret}, falling back to raw CodeQL results",
+            )
+            return results_json
+
+    async def _update_sink(self, sink_dict: dict) -> bool:
+        """Resolve bytecode and feed a single sink to sinkmanager."""
+        try:
+            cwe = sink_dict.get("cwe")
+            mark_desc = CWE_TO_MARK_DESC.get(cwe)
+            if mark_desc is None:
+                self.logH(
+                    None,
+                    f"{CRS_WARN} Skipping sinkpoint with unknown CWE '{cwe}' (no mark_desc mapping)",
+                )
+                return False
+
+            # Honor the filtering agent's verdict: keep the sink only if at
+            # least one harness in `in_final_result` is a harness the CRS is
+            # actually targeting in this run.
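+            # e.g. in_final_result=["OssFuzz1", "OssFuzz2"] with targeted
+            # harnesses {"OssFuzz2"} keeps the sink (matched via OssFuzz2),
+            # while an empty intersection drops it below.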
+ in_final_result = sink_dict.get("in_final_result") or [] + target_harnesses = set(self.crs.get_target_harnesses()) + matched_harnesses = [h for h in in_final_result if h in target_harnesses] + if not matched_harnesses: + self.logH( + None, + f"{CRS_WARN} Skipping sinkpoint (cwe={cwe}): agent in_final_result={in_final_result} " + f"has no overlap with targeted harnesses {sorted(target_harnesses)}", + ) + return False + + coord_dict = sink_dict["coord"] + code_coord = self.crs.query_code_coord( + coord_dict["class_name"], coord_dict["line_num"] + ) + if code_coord is None: + self.logH( + None, + f"{CRS_WARN} Filter out sinkpoint {coord_dict['class_name']}:{coord_dict['line_num']} which has no code coordinate", + ) + return False + + self.logH( + None, + f"Sinkpoint {coord_dict['class_name']}:{coord_dict['line_num']} found code coordinate: {code_coord}", + ) + coord_dict.update(asdict(code_coord)) + coord_dict["mark_desc"] = mark_desc + sink = Sinkpoint.frm_dict(sink_dict) + self.logH( + None, + f"Feeding sinkpoint to sinkmanager (matched harnesses={matched_harnesses}): {sink}", + ) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="sinkdetection") + return True + except Exception as e: + self.logH( + None, + f"{CRS_ERR} updating sink {sink_dict} to sinkmanager: {str(e)} {traceback.format_exc()}", + ) + return False + + async def _feed_sinks_to_sinkmanager(self, results_path: Path): + """Parse sink results and feed each to sinkmanager.""" + self.logH(None, f"Parsing results from {results_path}") + try: + if not results_path.exists(): + self.logH(None, f"{CRS_ERR} Results file not found: {results_path}") + return + + async with aiofiles.open(results_path, "r") as f: + results = await f.read() + + sinkpoints = json.loads(results) + if not isinstance(sinkpoints, list): + self.logH(None, f"{CRS_ERR} Invalid results format: expected list") + return + + total_sinks = len(sinkpoints) + self.logH(None, f"Found {total_sinks} sinkpoints in results") + + kept_sinks = 0 + for sink_dict in sinkpoints: + if await self._update_sink(sink_dict): + kept_sinks += 1 + + filtered_sinks = total_sinks - kept_sinks + self.logH( + None, f"Sinkpoints stats: {filtered_sinks} filtered, {kept_sinks} kept" + ) + + except Exception as e: + self.logH( + None, + f"{CRS_ERR} Result parsing error: {str(e)} {traceback.format_exc()}", + ) + + async def _async_run(self, _) -> Dict[str, Any]: + if not self.enabled: + self.logH(None, f"Module {self.name} is disabled") + return + + self.logH(None, f"Starting {self.name}") + + try: + cpu_list = await self.crs.cpuallocator.poll_allocation(None, self.name) + self.logH(None, f"Allocated CPUs: {cpu_list}") + + if not self.db.exists(): + self.logH(None, f"{CRS_ERR} CodeQL database not found at {self.db}, skipping") + return + + # 1. Run CodeQL queries + self.logH(None, f"Using CodeQL database at {self.db}") + await self._run_codeql_query(cpu_list) + + if not self.results_file.exists(): + self.logH(None, f"{CRS_ERR} CodeQL produced no results, skipping") + return + + # 2. Wait for call graph and run reachability filter + call_graph_path = await self._wait_for_call_graph() + reachable_sinks_file = await self._run_reachability_filter(call_graph_path) + + # 3. Run filtering agent + assessed_file = await self._run_filtering_agent(reachable_sinks_file, cpu_list, call_graph_path) + + # 4. 
Feed filtered sinks to sinkmanager + await self._feed_sinks_to_sinkmanager(assessed_file) + + except Exception as e: + self.logH( + None, f"{CRS_ERR} Module failed: {str(e)} {traceback.format_exc()}" + ) + + self.logH(None, f"{self.name} ended") diff --git a/crs/javacrs_modules/sinkmanager.py b/crs/javacrs_modules/sinkmanager.py index 72715c917..c28cd4ae9 100755 --- a/crs/javacrs_modules/sinkmanager.py +++ b/crs/javacrs_modules/sinkmanager.py @@ -36,6 +36,17 @@ class SinkManagerParams(BaseModel): enabled: bool = Field( ..., description="**Mandatory**, true/false to enable/disable this module." ) + allowed_sink_contributors: List[str] = Field( + default_factory=list, + description=( + "**Optional**, list of module names that are allowed to add NEW sinks " + "to the pool. Modules not in this list can still update metadata of " + "existing sinks but cannot introduce new ones. Empty list means all " + "modules are allowed (backward compat). Note: Redis-synced sinks count " + "as 'sinkmanager' source — include 'sinkmanager' here if you use " + "multi-pod fuzzing with a non-empty allowlist." + ), + ) @field_validator("enabled") def enabled_should_be_boolean(cls, v): @@ -73,6 +84,7 @@ def __init__( super().__init__(name, crs, run_per_harness) self.params = params self.enabled = self.params.enabled + self.allowed_sink_contributors = set(self.params.allowed_sink_contributors) self.ttl_fuzz_time: int = self.crs.ttl_fuzz_time self.workdir = self.get_workdir("") / self.crs.cp.name self.workdir.mkdir(parents=True, exist_ok=True) @@ -205,7 +217,7 @@ async def _pull_sinkpoint_from_redis(self, sp_key: str): sinkpoint = await self._get_sink_from_redis(sp_key) if sinkpoint is None: return - await self.on_event_update_sinkpoint(sinkpoint) + await self.on_event_update_sinkpoint(sinkpoint, source="sinkmanager") async def _push_sinkpoint_to_redis( self, sp_key: str, remote_hash: Optional[str] @@ -354,7 +366,9 @@ async def _dump_to_sink_conf_path(self): async with self._lock: confs = [sink.coord.to_conf() for sink in self.sinkpoints.values()] confs = [conf for conf in confs if conf] - conf_ctnt = "\n".join(confs) + # Each line ends with '\n', including the last one. POSIX-correct + # text file; lets readers like `wc -l` count sinks accurately. 
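+            # e.g. confs = ["a", "b"]: '"\n".join(confs)' yields "a\nb" with no
+            # trailing newline (`wc -l` reports 1), while the join below yields
+            # "a\nb\n" (`wc -l` reports 2).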
+ conf_ctnt = "".join(f"{c}\n" for c in confs) await atomic_write_file(self.sink_conf_path, conf_ctnt) self.logH(None, f"Dumped sinkpoints to {self.sink_conf_path}") except Exception as e: @@ -383,15 +397,32 @@ async def _sync_sinkpoints_to_fs(self): await asyncio.sleep(1) self.logH(None, "Reached end time, exiting _sync_sinkpoints_to_fs") - async def _update_sink(self, sink: Sinkpoint) -> List[SinkUpdateEvent]: + def _is_contributor_allowed(self, source: str) -> bool: + """Check if a source module is allowed to add new sinks.""" + # Empty allowlist means all sources allowed (backward compat) + if not self.allowed_sink_contributors: + return True + return source in self.allowed_sink_contributors + + async def _update_sink( + self, sink: Sinkpoint, source: str + ) -> List[SinkUpdateEvent]: async with self._lock: updated_sink = None # Update sink obj if sink.coord not in self.sinkpoints: + # New sink: enforce contributor allowlist + if not self._is_contributor_allowed(source): + self.logH( + None, + f"Dropping new sink {sink.coord} from '{source}' (not in allowlist {sorted(self.allowed_sink_contributors)})", + ) + return [] self.sinkpoints[sink.coord] = sink updated_sink = sink else: + # Existing sink: metadata updates always allowed if self.sinkpoints[sink.coord].merge(sink): updated_sink = self.sinkpoints[sink.coord] @@ -586,14 +617,21 @@ async def on_event_sarif_challenge_solved(self, sarif_id: UUID): f"{CRS_ERR} in on_event_sarif_challenge_solved: {str(e)} {traceback.format_exc()}", ) - async def on_event_update_sinkpoint(self, sink: Sinkpoint): - """Handle new sinkpoint.""" + async def on_event_update_sinkpoint(self, sink: Sinkpoint, source: str): + """Handle new sinkpoint. + + Args: + sink: The Sinkpoint to add or merge. + source: Name of the calling module (e.g. 'sinkdetection', 'crashmanager'). + Used to enforce the `allowed_sink_contributors` allowlist for new + sinks. Existing sinks always accept metadata updates regardless. 
+ """ if not self.enabled: self.logH(None, f"Module {self.name} is disabled, skip sink update event") return try: - events = await self._update_sink(sink) + events = await self._update_sink(sink, source) await self._notify_sink_update_events("SINK-UPDATE-EVENT", events) except Exception as e: self.logH( diff --git a/crs/javacrs_modules/staticanalysis.py b/crs/javacrs_modules/staticanalysis.py index 9c4913e50..de87af611 100755 --- a/crs/javacrs_modules/staticanalysis.py +++ b/crs/javacrs_modules/staticanalysis.py @@ -253,7 +253,7 @@ async def _update_target(self, target: DirectedFuzzTarget) -> None: try: sink = Sinkpoint.frm_dict(target.get_target_location()) self.logH(None, f"Static analysis update sinkpoint to sinkmanager: {sink}") - await self.crs.sinkmanager.on_event_update_sinkpoint(sink) + await self.crs.sinkmanager.on_event_update_sinkpoint(sink, source="staticanalysis") except Exception as e: self.logH( None, @@ -328,8 +328,6 @@ async def _run_static_analysis( str(self.static_ana_jar.resolve()), "--config", str(config_file.resolve()), - "--target-file", - str(self.crs.meta.sink_target_conf.resolve()), "--distance-map-file", str(self.static_ana_result.resolve()), "--cg-stages", diff --git a/crs/javacrscfg.py b/crs/javacrscfg.py index 8614afe45..51365e8ff 100755 --- a/crs/javacrscfg.py +++ b/crs/javacrscfg.py @@ -9,7 +9,7 @@ AtlDirectedJazzerParams, AtlJazzerParams, AtlLibAFLJazzerParams, - CodeQLParams, + SinkDetectionParams, ConcolicExecutorParams, CPUAllocatorParams, CrashManagerParams, @@ -65,7 +65,7 @@ class ModuleParams(BaseModel): concolic: ConcolicExecutorParams = Field( ..., description="ConcolicExecutor module parameters." ) - codeql: CodeQLParams = Field(..., description="CodeQL module parameters.") + sinkdetection: SinkDetectionParams = Field(..., description="SinkDetection module parameters.") dictgen: DictgenParams = Field(..., description="Dictgen module parameters.") diff_scheduler: DiffSchedulerParams = Field( ..., description="DiffScheduler module parameters." @@ -112,8 +112,8 @@ class Config: description="**Optional**, if set, enable sync log to NFS right after e2e. Default is False.", ) ssmode: bool = Field( - False, - description="**Optional**, if set, enable SS mode (Specified Sink only mode) for the CRS. Default is False.", + True, + description="**Optional**, if set, only use sinks from sinkmanager. If False, also use all Jazzer static sinks. 
Default is True.",
     )
     modules: ModuleParams = Field(..., description="Module parameters.")
@@ -240,6 +240,16 @@ def update_cfg_from_env(conf: dict) -> dict:
     return conf
 
 
+def deep_merge(base: dict, override: dict) -> dict:
+    """Recursively merge override into base, preserving existing keys at all levels."""
+    for key, value in override.items():
+        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
+            deep_merge(base[key], value)
+        else:
+            base[key] = value
+    return base
+
+
 def update_remote_specified_cfg(local_conf_file: str):
     with open(local_conf_file) as f:
         local_conf = json.load(f)
@@ -250,7 +260,7 @@ def update_remote_specified_cfg(local_conf_file: str):
         print(f"Merging custom configuration from JAVACRS_CFG: {javacrs_cfg}")
         with open(javacrs_cfg) as f:
             remote_conf = json.load(f)
-        local_conf.update(remote_conf)
+        deep_merge(local_conf, remote_conf)
     elif javacrs_cfg is not None:
         print(f"Custom config not found at: {javacrs_cfg}, using default configuration")
 
diff --git a/crs/libs/coordinates/bytecode-parser/pom.xml b/crs/libs/coordinates/bytecode-parser/pom.xml
index 1a5f58a23..7d7ffc245 100644
--- a/crs/libs/coordinates/bytecode-parser/pom.xml
+++ b/crs/libs/coordinates/bytecode-parser/pom.xml
@@ -9,16 +9,17 @@
     <version>1.0-SNAPSHOT</version>
 
     <dependencies>
         <dependency>
             <groupId>org.ow2.asm</groupId>
             <artifactId>asm</artifactId>
-            <version>9.7</version>
+            <version>9.8-atlantis</version>
         </dependency>
         <dependency>
             <groupId>org.ow2.asm</groupId>
             <artifactId>asm-tree</artifactId>
-            <version>9.7</version>
+            <version>9.8-atlantis</version>
         </dependency>
     </dependencies>
diff --git a/crs/libs/coordinates/bytecode-parser/src/main/java/BytecodeInspector.java b/crs/libs/coordinates/bytecode-parser/src/main/java/BytecodeInspector.java
index c4718b816..f2ef24288 100644
--- a/crs/libs/coordinates/bytecode-parser/src/main/java/BytecodeInspector.java
+++ b/crs/libs/coordinates/bytecode-parser/src/main/java/BytecodeInspector.java
@@ -128,9 +128,10 @@ static void processClassFile(Path classPath) throws IOException {
     }
 
     static void processClass(InputStream is, String jarFilePath, String classFilePath) throws IOException {
-        ClassReader cr = new ClassReader(is);
+        byte[] bytecode = is.readAllBytes();
+        ClassReader cr = new ClassReader(bytecode);
         ClassNode classNode = new ClassNode();
-        cr.accept(classNode, ClassReader.SKIP_FRAMES);
+        cr.accept(classNode, 0);
 
         String className = classNode.name.replace('/', '.');
         if (!isWhitelisted(className)) {
@@ -140,28 +141,36 @@ static void processClass(InputStream is, String jarFilePath, String classFilePat
         String sourceFileName = classNode.sourceFile != null ?
classNode.sourceFile : "Unknown";
 
         for (MethodNode method : classNode.methods) {
+            int curLine = -1;
+            Set<Integer> recordedLines = new HashSet<>();
+
             for (AbstractInsnNode insn : method.instructions) {
                 if (insn instanceof LineNumberNode) {
-                    LineNumberNode lineNode = (LineNumberNode) insn;
-                    int offset = method.instructions.indexOf(lineNode.start);
-                    int line = lineNode.line;
-
-                    CodeLocation loc = new CodeLocation(
-                        jarFilePath,
-                        classFilePath,
-                        className,
-                        sourceFileName,
-                        method.name,
-                        method.desc,
-                        offset,
-                        line
-                    );
-
-                    // Use className as the outer key instead of sourceFileName
-                    index.computeIfAbsent(className, k -> new HashMap<>())
-                        .computeIfAbsent(line, k -> new ArrayList<>())
-                        .add(loc);
+                    curLine = ((LineNumberNode) insn).line;
+                    continue;
+                }
+                if (insn.getOpcode() < 0 || curLine < 0) {
+                    continue;
+                }
+                if (recordedLines.contains(curLine)) {
+                    continue;
                 }
+                recordedLines.add(curLine);
+
+                CodeLocation loc = new CodeLocation(
+                    jarFilePath,
+                    classFilePath,
+                    className,
+                    sourceFileName,
+                    method.name,
+                    method.desc,
+                    insn.getBytecodeOffset(),
+                    curLine
+                );
+
+                index.computeIfAbsent(className, k -> new HashMap<>())
+                    .computeIfAbsent(curLine, k -> new ArrayList<>())
+                    .add(loc);
             }
         }
     }
diff --git a/crs/llm-poc-gen/tests/test_model_manager.py b/crs/llm-poc-gen/tests/test_model_manager.py
index 4317ee85f..7de5eb481 100644
--- a/crs/llm-poc-gen/tests/test_model_manager.py
+++ b/crs/llm-poc-gen/tests/test_model_manager.py
@@ -237,6 +237,40 @@ async def test_invoke_non_register_model():
         await ModelManager().invoke_atomic([], model_name, None)
 
 
+@pytest.mark.asyncio
+async def test_invoke_unknown_name_with_other_registered():
+    """Strict dispatch: unknown model raises even if others are registered."""
+    await ModelManager().add_model(
+        lambda input, output: input + output, "mock", MockChatModel()
+    )
+    with pytest.raises(RuntimeError):
+        await ModelManager().invoke_atomic([], "fake", None)
+    with pytest.raises(RuntimeError):
+        await ModelManager().invoke([], "fake", None)
+
+
+@pytest.mark.asyncio
+async def test_resolve_model_fallback():
+    await ModelManager().add_model(
+        lambda input, output: input + output, "mock", MockChatModel()
+    )
+    assert ModelManager().resolve_model("mock") == "mock"
+    assert ModelManager().resolve_model("missing") == "mock"
+
+
+@pytest.mark.asyncio
+async def test_resolve_models_preserves_preference_order():
+    await ModelManager().add_model(
+        lambda input, output: input + output, "a", MockChatModel()
+    )
+    await ModelManager().add_model(
+        lambda input, output: input + output, "b", MockChatModel()
+    )
+    assert ModelManager().resolve_models(["b", "missing", "a"]) == ["b", "a"]
+    # No preferences match → return all registered as fallback
+    assert sorted(ModelManager().resolve_models(["x", "y"])) == ["a", "b"]
+
+
 @pytest.mark.asyncio
 @patch.object(ModelManager, "_invoke_atomic")
 async def test_invoke_runtime_error_from__invoke_atomic(patch_1):
diff --git a/crs/llm-poc-gen/vuli/agents/exploit.py b/crs/llm-poc-gen/vuli/agents/exploit.py
index f78336fdc..d473be9bb 100644
--- a/crs/llm-poc-gen/vuli/agents/exploit.py
+++ b/crs/llm-poc-gen/vuli/agents/exploit.py
@@ -107,11 +107,9 @@ def need_regenerate(self, state: dict) -> bool:
     async def extend_code(self, state: dict) -> bool:
         self._logger.info("PoV Code Extension Start")
         state["extend"] = True
-        models: list[str] = [
-            x
-            for x in ["gemini-2.5-pro", "o3", "claude-opus-4-20250514"]
-            if x in state["models"]
-        ]
+        models: list[str] = ModelManager().resolve_models(
+            ["gemini-2.5-pro", "o3",
"claude-opus-4-20250514"] + ) if len(models) == 0: self._logger.info("No Model. Skip Expand for PoV Generation") return state @@ -154,7 +152,7 @@ async def extend_code(self, state: dict) -> bool: try: result: dict[str, dict] = { path: await ModelManager().invoke_atomic( - messages, model_name, JsonParser() + messages, model_name, JsonParser(), agent="exploit" ) } state["code_table"] = self.update_code_table( diff --git a/crs/llm-poc-gen/vuli/agents/extender.py b/crs/llm-poc-gen/vuli/agents/extender.py index fa036c79d..417b975a9 100644 --- a/crs/llm-poc-gen/vuli/agents/extender.py +++ b/crs/llm-poc-gen/vuli/agents/extender.py @@ -114,7 +114,7 @@ async def extend_per_file( ): try: return await ModelManager().invoke_atomic( - copied_messages, model_name, JsonParser() + copied_messages, model_name, JsonParser(), agent="extender" ) except LLMRetriable: await asyncio.sleep(60) diff --git a/crs/llm-poc-gen/vuli/agents/generator.py b/crs/llm-poc-gen/vuli/agents/generator.py index 78c556676..dc20d3e4a 100644 --- a/crs/llm-poc-gen/vuli/agents/generator.py +++ b/crs/llm-poc-gen/vuli/agents/generator.py @@ -124,9 +124,10 @@ async def localize(self, state: dict) -> dict: {code} ```""" ) + model_name: str = ModelManager().resolve_model("gpt-4.1") try: result: dict = await ModelManager().invoke( - [message], "gpt-4.1", JsonParser() + [message], model_name, JsonParser(), agent="generator" ) except Exception as e: self._logger.warning(f"Skip Exception: {e}") diff --git a/crs/llm-poc-gen/vuli/agents/pathvalidator.py b/crs/llm-poc-gen/vuli/agents/pathvalidator.py index 31fdb0fbe..500efdb38 100644 --- a/crs/llm-poc-gen/vuli/agents/pathvalidator.py +++ b/crs/llm-poc-gen/vuli/agents/pathvalidator.py @@ -107,7 +107,8 @@ async def validate(self, state: dict) -> dict: """ ) messages.append(message) - model_result: dict = await ModelManager().invoke(messages, "gpt-4.1", Parser()) + model_name: str = ModelManager().resolve_model("gpt-4.1") + model_result: dict = await ModelManager().invoke(messages, model_name, Parser(), agent="pathvalidator") cache: bool = model_result.get("cache", False) cost: float = model_result.get("cost", 0.0) result: bool = model_result.get("result", "OK") diff --git a/crs/llm-poc-gen/vuli/blobgen.py b/crs/llm-poc-gen/vuli/blobgen.py index 2b557a8fa..fe098129c 100644 --- a/crs/llm-poc-gen/vuli/blobgen.py +++ b/crs/llm-poc-gen/vuli/blobgen.py @@ -145,7 +145,7 @@ async def generate( ) -> BlobGeneratorResult: try: model_result: list[dict] = await ModelManager().invoke( - messages, model_name, parser + messages, model_name, parser, agent="blobgen" ) except Exception as e: self._logger.warning(f"Skip Exception: {e}") @@ -265,7 +265,7 @@ async def _create_blob_generator_result( """ try: model_output: list[dict] = await ModelManager().invoke_atomic( - messages, model_name, parser + messages, model_name, parser, agent="blobgen" ) except LLMParseException: self._logger.info( @@ -281,9 +281,10 @@ async def _create_blob_generator_result( ) except LLMRetriable as e: raise e - except Exception: - self._logger.info( - f"Blob Generation Failed [reason=LLM Failed, model={model_name}]" + except Exception as e: + self._logger.exception( + f"Blob Generation Failed [reason=LLM Failed, model={model_name}, " + f"exc={e.__class__.__name__}: {e}]" ) return None diff --git a/crs/llm-poc-gen/vuli/chatlog.py b/crs/llm-poc-gen/vuli/chatlog.py new file mode 100644 index 000000000..81a9b80ee --- /dev/null +++ b/crs/llm-poc-gen/vuli/chatlog.py @@ -0,0 +1,31 @@ +from pathlib import Path +from typing import Optional + +from 
langchain_core.messages import BaseMessage + +from vuli.common.singleton import Singleton + + +class ChatLog(metaclass=Singleton): + def __init__(self): + self._dir: Optional[Path] = None + self._counters: dict[str, int] = {} + + def initialize(self, output_dir: Path) -> None: + self._dir = output_dir / "llm-chats" + self._dir.mkdir(exist_ok=True) + + def log( + self, agent: str, model: str, messages: list[BaseMessage], response: BaseMessage + ) -> None: + if self._dir is None: + return + seq = self._counters.get(agent, 0) + 1 + self._counters[agent] = seq + path = self._dir / f"{agent}-{seq:03d}.log" + with open(path, "w") as f: + f.write(f"Model: {model}\n\n") + for msg in messages: + f.write(f"{msg.pretty_repr()}\n\n") + f.write("--- Response ---\n\n") + f.write(f"{response.pretty_repr()}\n") diff --git a/crs/llm-poc-gen/vuli/commandline.py b/crs/llm-poc-gen/vuli/commandline.py index 4a08f22d7..35c5fce16 100644 --- a/crs/llm-poc-gen/vuli/commandline.py +++ b/crs/llm-poc-gen/vuli/commandline.py @@ -23,6 +23,8 @@ class CommandLineOption(BaseModel): server_dir: Optional[Path] = None shared_dir: Optional[Path] = None diff_threashold: int + models: Optional[list[str]] = None + scan_sinks: bool = True class CommandLineOptionBuilder: @@ -108,6 +110,18 @@ def build(self) -> Optional[CommandLineOption]: default=-1, help="Threshold for an LLM to attempt diff file analysis", ) + parser.add_argument( + "--models", + type=str, + default=None, + help="Comma-separated list of available LLM model names. If not set, uses mode-specific defaults.", + ) + parser.add_argument( + "--scan-sinks", + action="store_true", + default=False, + help="Enable Joern-based sink discovery (FROM_INSIDE). If not set, relies solely on CRS sinkmanager sinks.", + ) args = parser.parse_args() cp_meta: Path = Path(args.cp_meta).absolute() @@ -166,4 +180,6 @@ def build(self) -> Optional[CommandLineOption]: server_dir=server_dir, shared_dir=shared, diff_threashold=args.diff_threashold, + models=[m.strip() for m in args.models.split(",")] if args.models else None, + scan_sinks=args.scan_sinks, ) diff --git a/crs/llm-poc-gen/vuli/delta.py b/crs/llm-poc-gen/vuli/delta.py index 9e98f7a9d..3f6e86f8c 100644 --- a/crs/llm-poc-gen/vuli/delta.py +++ b/crs/llm-poc-gen/vuli/delta.py @@ -247,6 +247,8 @@ def _create_hunk_messages(self, hunks: list[tuple[PatchedFile, Hunk]]) -> list[s return hunk_msgs async def _infer_sinks(self, hunk_messages: list[str]) -> list[dict]: + model_name: str = ModelManager().resolve_model("gpt-4.1") + async def _infer(msg: str) -> list[dict]: messages: list[BaseMessage] = [ HumanMessage( @@ -267,7 +269,7 @@ async def _infer(msg: str) -> list[dict]: for i in range(0, 3): try: return await ModelManager().invoke_atomic( - messages, "gpt-4.1", DeltaParser() + messages, model_name, DeltaParser(), agent="delta" ) except LLMRetriable: await asyncio.sleep(60) diff --git a/crs/llm-poc-gen/vuli/main.py b/crs/llm-poc-gen/vuli/main.py index 65c0b745d..df4029aaf 100644 --- a/crs/llm-poc-gen/vuli/main.py +++ b/crs/llm-poc-gen/vuli/main.py @@ -5,6 +5,7 @@ from typing import Optional from vuli.blackboard import Blackboard +from vuli.chatlog import ChatLog from vuli.commandline import CommandLineOption, CommandLineOptionBuilder from vuli.common.setting import Setting from vuli.cp import CP @@ -35,6 +36,7 @@ def initialize_system( ) -> None: root_dir: Path = Path(__file__).parent.parent.absolute() Setting().load(jazzer, joern_dir, output_dir, root_dir, dev, shared_dir) + ChatLog().initialize(Setting().output_dir) 
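+    # ChatLog stays a no-op until initialized; afterwards each
+    # ChatLog().log(agent, ...) call writes a numbered transcript such as
+    # <output_dir>/llm-chats/blobgen-001.log (one counter per agent).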
asyncio.run(Blackboard().set_path(Setting().blackboard_path)) CP().load(cp_meta, harnesses, cg_paths) CP()._server_dir = server_dir @@ -103,7 +105,8 @@ def main(): cmd_option.diff_threashold, ) runner: Optional[Runner] = create_runner( - cmd_option.mode, cmd_option.workers, cmd_option.model_cache + cmd_option.mode, cmd_option.workers, cmd_option.model_cache, + cmd_option.models, cmd_option.scan_sinks, ) if runner is None: logger.error(f"Not Found Runner [mode={cmd_option.mode}]") diff --git a/crs/llm-poc-gen/vuli/model_manager.py b/crs/llm-poc-gen/vuli/model_manager.py index a569cf425..96afd8bf0 100644 --- a/crs/llm-poc-gen/vuli/model_manager.py +++ b/crs/llm-poc-gen/vuli/model_manager.py @@ -31,6 +31,7 @@ wait_random, ) +from vuli.chatlog import ChatLog from vuli.common.decorators import SEVERITY, async_lock, async_safe from vuli.common.singleton import Singleton from vuli.struct import LLMParseException, LLMRetriable @@ -347,6 +348,32 @@ async def add_model( def get_all_model_names(self) -> list[str]: return sorted(list(self._models.keys())) + def resolve_model(self, preferred: str) -> str: + """Resolve a preferred model name against available models. + + Returns the preferred model if available, otherwise the first + registered model. Raises RuntimeError if no models are registered. + """ + if preferred in self._models: + return preferred + if self._models: + fallback = next(iter(self._models)) + self._logger.warning(f"Model '{preferred}' not available, falling back to '{fallback}'") + return fallback + raise RuntimeError(f"No models registered (requested '{preferred}')") + + def resolve_models(self, preferences: list[str]) -> list[str]: + """Filter a preference list against available models. + + Returns the subset of preferred models that are registered, preserving + order. If none match, returns all registered models. 
+ """ + resolved = [m for m in preferences if m in self._models] + if resolved: + return resolved + # None of the preferred models are available — return all available + return list(self._models.keys()) + def get_total_usage(self) -> tuple[float, float]: total_cost: float = 0.0 total_saved: float = 0.0 @@ -390,12 +417,13 @@ async def invoke_atomic( messages: list[BaseMessage], model_name: str, parser: Optional[RunnableSequence], + agent: str = "", ) -> Any: """ Raises: RuntimeError, LLMRetriable """ if model_name not in self._models: - raise RuntimeError(f"Unregistered Model: {model_name}") + raise RuntimeError(f"Model '{model_name}' is not registered") metadata: ModelMetadata = self._models[model_name] if self._cache: @@ -408,6 +436,9 @@ async def invoke_atomic( except Exception as e: raise e + if agent: + ChatLog().log(agent, model_name, messages, message) + try: _, result, _ = await self._retry_parse( metadata.model, @@ -415,11 +446,18 @@ async def invoke_atomic( message.content, {"callbacks": [metadata.usage]}, 1, + agent=agent, ) except LLMRetriable as e: raise e - except Exception: - raise RuntimeError("LLM Output has unexpected format") + except Exception as e: + self._logger.exception( + f"LLM Output has unexpected format " + f"[model={model_name}, exc={e.__class__.__name__}: {e}]" + ) + raise RuntimeError( + f"LLM Output has unexpected format: {e.__class__.__name__}: {e}" + ) from e return result @async_lock("_lock") @@ -428,12 +466,13 @@ async def invoke( messages: list[BaseMessage], model_name: str, parser: Optional[RunnableSequence], + agent: str = "", ) -> Any: """ Raises: RuntimeError, LLMParseException, LLMRetriable, RuntimeError """ if model_name not in self._models: - raise RuntimeError(f"Unregistered Model: {model_name}") + raise RuntimeError(f"Model '{model_name}' is not registered") metadata: ModelMetadata = self._models[model_name] if self._cache: @@ -449,20 +488,24 @@ async def invoke( metadata.model, messages, {"callbacks": [metadata.usage]} ) self._logger.debug(f"LLM Response [{message.pretty_repr()}]") + if agent: + ChatLog().log(agent, model_name, messages, message) _, result, _ = await self._retry_parse( metadata.model, parser, message.content, {"callbacks": [metadata.usage]}, self._max_retries, + agent=agent, ) return result except LLMParseException as e: if i == self._max_retries: raise e except Exception as e: - self._logger.warning( - f"Skip Exception [case=while handling LLM answer, msg={e}]" + self._logger.exception( + f"Skip Exception [case=while handling LLM answer, " + f"model={model_name}, exc={e.__class__.__name__}: {e}]" ) raise RuntimeError("Unexpected State") @@ -508,7 +551,21 @@ async def _invoke_atomic( status_code: int = getattr(e, "status_code", 0) if status_code == 429 or status_code >= 500: raise LLMRetriable("") - raise RuntimeError("Failed to get response from LLM") + # Full traceback — re-runs are expensive, want everything first try. 
+ self._logger.exception( + f"Failed to get response from LLM " + f"[model={runnable.model_name}, exc={e.__class__.__name__}: {e}]" + ) + if isinstance(e, APIStatusError): + body = getattr(e, "response", None) + body_text = getattr(body, "text", None) if body is not None else None + self._logger.warning( + f"APIStatusError detail " + f"[status={getattr(e, 'status_code', '?')}, body={body_text!r}]" + ) + raise RuntimeError( + f"Failed to get response from LLM: {e.__class__.__name__}: {e}" + ) from e async def _retry_parse( self, @@ -517,6 +574,7 @@ async def _retry_parse( completion: str, config: dict = {}, max_retries=1, + agent: str = "", ): """ Raises: LLMParseException, LLMRetriable, RuntimeException @@ -537,5 +595,9 @@ async def _retry_parse( message: BaseMessage = await self._invoke_atomic( runnable, messages, config ) + if agent: + ChatLog().log( + agent, runnable.model_name, messages, message + ) parse_content = message.content raise LLMParseException("Failed to parse") diff --git a/crs/llm-poc-gen/vuli/models.py b/crs/llm-poc-gen/vuli/models.py index 2f8c7d6df..14620428a 100644 --- a/crs/llm-poc-gen/vuli/models.py +++ b/crs/llm-poc-gen/vuli/models.py @@ -53,6 +53,51 @@ def name(self): return "claude-opus-4-20250514" +class Gemini25FlashLite(LLMModel): + def cost(self, input, output): + return input * 0.0000001 + output * 0.0000004 + + @property + def name(self): + return "gemini-2.5-flash-lite" + + +class ClaudeSonnet45(LLMModel): + def cost(self, input, output): + return input * 0.000003 + output * 0.000015 + + @property + def name(self): + return "claude-sonnet-4-5-20250929" + + +class GPT5(LLMModel): + def cost(self, input, output): + return input * 0.00000125 + output * 0.000010 + + @property + def name(self): + return "gpt-5" + + +class GPT5Mini(LLMModel): + def cost(self, input, output): + return input * 0.00000025 + output * 0.000002 + + @property + def name(self): + return "gpt-5-mini" + + +class GPT5Nano(LLMModel): + def cost(self, input, output): + return input * 0.00000005 + output * 0.0000004 + + @property + def name(self): + return "gpt-5-nano" + + class GPT41(LLMModel): def cost(self, input, output): return input * 0.000002 + output * 0.000008 @@ -71,14 +116,29 @@ def name(self): return "o3" +class GLM5(LLMModel): + def cost(self, input, output): + return input * 0.000001 + output * 0.0000032 + + @property + def name(self): + return "zai-org/GLM-5" + + def get_model(name: str) -> Optional[LLMModel]: models: dict[str, LLMModel] = { "gemini-2.5-pro": Gemini25Pro, + "gemini-2.5-flash-lite": Gemini25FlashLite, + "claude-sonnet-4-5-20250929": ClaudeSonnet45, "claude-sonnet-4-20250514": ClaudeSonnet4, "claude-opus-4-20250514": ClaudeOpus4, + "gpt-5": GPT5, + "gpt-5-mini": GPT5Mini, + "gpt-5-nano": GPT5Nano, "gpt-4.1": GPT41, "grok-3": Grok3, "o3": O3, + "glm5": GLM5, } if name in models: diff --git a/crs/llm-poc-gen/vuli/reflection.py b/crs/llm-poc-gen/vuli/reflection.py index 88b8ad695..8f5536ac6 100644 --- a/crs/llm-poc-gen/vuli/reflection.py +++ b/crs/llm-poc-gen/vuli/reflection.py @@ -294,6 +294,8 @@ async def __solve_for_cf(self, cf: list[int]) -> dict: if len(cf) == 0: return result + model_name: str = ModelManager().resolve_model("gpt-4.1") + system_message_1: BaseMessage = SystemMessage( content="""I want to know which functions can be called via reflection in the given code. Analyze the code and determine which class's methods can be invoked or which method names are callable. 
@@ -342,7 +344,7 @@ async def __solve_for_cf(self, cf: list[int]) -> dict: messages.append(HumanMessage(content=f"\n{code}")) try: model_result: dict = await ModelManager().invoke_atomic( - messages, "gpt-4.1", JsonParser() + messages, model_name, JsonParser(), agent="reflection" ) except Exception as e: self._logger.warning(f"Skip Exception: {e}") @@ -390,7 +392,7 @@ async def __solve_for_cf(self, cf: list[int]) -> dict: ) try: model_result: dict = await ModelManager().invoke( - messages, "gpt-4.1", JsonParser() + messages, model_name, JsonParser(), agent="reflection" ) except Exception as e: self._logger.warning(f"Skip Exception: {e}") diff --git a/crs/llm-poc-gen/vuli/runner.py b/crs/llm-poc-gen/vuli/runner.py index 10b83f2ca..9536844ec 100644 --- a/crs/llm-poc-gen/vuli/runner.py +++ b/crs/llm-poc-gen/vuli/runner.py @@ -130,13 +130,15 @@ async def _save_output(self, start_time: time.time) -> None: class StandAlone(Runner): - def __init__(self, workers: int = 1): + def __init__(self, workers: int = 1, scan_sinks: bool = True): super().__init__() self._logger = logging.getLogger(self.__class__.__name__) self._blobgen = BlobGeneration(CoverageBasedGeneration(), workers) + self._scan_sinks = scan_sinks async def _run(self) -> None: - await Scanner().run(CP().sanitizers) + if self._scan_sinks: + await Scanner().run(CP().sanitizers) await DeltaManager().handle() # await ReflectionSolver(CP().get_harnesses()).run() await FindPathService()._run() @@ -145,9 +147,10 @@ async def _run(self) -> None: class CRS(Runner): - def __init__(self, workers: int = 1): + def __init__(self, workers: int = 1, scan_sinks: bool = True): super().__init__() self._logger = logging.getLogger("CRS") + self._scan_sinks = scan_sinks sink_updater = SinkUpdateService() sink_updater.add_task(JavaCRS(CP()._sink_path)) TaskManager().add_handlers( @@ -179,7 +182,8 @@ def __init__(self, workers: int = 1): async def _run(self) -> None: TaskManager()._stop = False - await Scanner().run(CP().sanitizers) + if self._scan_sinks: + await Scanner().run(CP().sanitizers) await DeltaManager().handle() self._logger.info( @@ -242,7 +246,11 @@ async def _run(self): def create_runner( - mode: str, workers: int = 1, model_cache: Optional[Path] = None + mode: str, + workers: int = 1, + model_cache: Optional[Path] = None, + models: Optional[list[str]] = None, + scan_sinks: bool = True, ) -> Optional[Runner]: model_map = { "onetime": ["claude-sonnet-4-20250514", "o3", "gemini-2.5-pro", "gpt-4.1"], @@ -252,20 +260,23 @@ def create_runner( "default": [ "claude-opus-4-20250514", "o3", + "claude-sonnet-4-20250514", "gemini-2.5-pro", "gpt-4.1", ], } runner_map = { - "onetime": lambda: StandAlone(workers), + "onetime": lambda: StandAlone(workers, scan_sinks), "static": STATIC, "sink": SINK, "c_sarif": lambda: C_SARIF(workers), - "default": lambda: CRS(workers), + "default": lambda: CRS(workers, scan_sinks), } - models = model_map.get(mode, model_map["default"]) + if models is None: + models = model_map.get(mode, model_map["default"]) + runner_factory = runner_map.get(mode, runner_map["default"]) set_models(model_cache, models) diff --git a/crs/llm-poc-gen/vuli/task.py b/crs/llm-poc-gen/vuli/task.py index 9a6dc3195..f7dec85b2 100644 --- a/crs/llm-poc-gen/vuli/task.py +++ b/crs/llm-poc-gen/vuli/task.py @@ -146,8 +146,9 @@ async def run(self, path: VulInfo) -> path_manager.Status: ) task: dict = await self._reach(task) is_reached: bool = "error" not in task and task["reached"] is True + total_blobs: int = sum(task.get("history", {}).values()) 
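+        # "history" is populated by _reach() from the graph state; its values
+        # are per-step blob counts, so summing them gives the total number of
+        # blobs generated (exact keys depend on the reach graph).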
self._logger.info( - f"Finish Generation Reachable Blob [reached={is_reached}, harness={path.harness_id}, sink={path.v_point}]" + f"Finish Generation Reachable Blob [reached={is_reached}, total_blobs={total_blobs}, harness={path.harness_id}, sink={path.v_point}]" ) if not is_reached: @@ -196,11 +197,11 @@ async def _reach(self, task: dict) -> dict: "candidate": task.get("candidate", None), "code_table": task.get("code_table", {}), "harness_id": task.get("harness_id", ""), - "models": [ + "models": ModelManager().resolve_models([ "o3", "gemini-2.5-pro", "claude-opus-4-20250514", - ], + ]), } result_state: dict = await graph.ainvoke(input_state, {"recursion_limit": 100}) necessary_keys: set[str] = {"code_table", "point", "prev", "reached"} @@ -214,6 +215,7 @@ async def _reach(self, task: dict) -> dict: task["point"] = result_state["point"] task["reached"] = result_state["reached"] task["prev"] = result_state["prev"] + task["history"] = result_state.get("history", {}) return task async def _exploit(self, task: dict) -> dict: @@ -239,11 +241,11 @@ async def _exploit(self, task: dict) -> dict: "prev": task["prev"], "point": task["point"], "sanitizer": sanitizer, - "models": [ + "models": ModelManager().resolve_models([ "o3", "gemini-2.5-pro", "claude-opus-4-20250514", - ], + ]), } result_state: dict = await graph.ainvoke(input_state) necessary_keys: set[str] = {"crash", "prev"} diff --git a/crs/main.py b/crs/main.py index 046156da9..b34c5735e 100755 --- a/crs/main.py +++ b/crs/main.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 import asyncio import logging -import os -import shutil import sys import time import traceback @@ -16,7 +14,7 @@ AtlDirectedJazzer, AtlJazzer, AtlLibAFLJazzer, - CodeQL, + SinkDetection, ConcolicExecutor, CPUAllocator, CrashManager, @@ -103,24 +101,6 @@ def __init__(self, name: str, hr_cls: type[HarnessRunner], conf_file: Path): else: self.inspector = None - # Handle ssmode sink file copy - if self.is_ssmode(): - java_crs_src = os.environ.get("JAVA_CRS_SRC") - if java_crs_src: - ssmode_sink_path = Path(java_crs_src) / "ssmode-sink.txt" - sink_targets_path = Path(java_crs_src) / "sink-targets.txt" - if ssmode_sink_path.exists(): - shutil.copy2(ssmode_sink_path, sink_targets_path) - self.log( - f"Copied {ssmode_sink_path} to {sink_targets_path} for ssmode" - ) - else: - self.log( - f"Warning: ssmode-sink.txt not found at {ssmode_sink_path}" - ) - else: - self.log("Warning: JAVA_CRS_SRC environment variable not set") - def _init_modules(self) -> List[Module]: module_list = [ # cp level modules @@ -135,7 +115,7 @@ def _init_modules(self) -> List[Module]: ("deepgen", DeepGenModule, False), ("dictgen", Dictgen, False), ("diff_scheduler", DiffScheduler, False), - ("codeql", CodeQL, False), + ("sinkdetection", SinkDetection, False), # per-harness modules ("concolic", ConcolicExecutor, True), ("aixccjazzer", AIxCCJazzer, True), @@ -244,7 +224,7 @@ async def main(self): self.deepgen.async_run(None), self.dictgen.async_run(None), self.diff_scheduler.async_run(None), - self.codeql.async_run(None), + self.sinkdetection.async_run(None), # libCRS entry func: inits harness runners and harness-level crs modules self.async_run(False), ] diff --git a/crs/run-crs-java.sh b/crs/run-crs-java.sh index c97778286..a9c35952b 100755 --- a/crs/run-crs-java.sh +++ b/crs/run-crs-java.sh @@ -79,7 +79,8 @@ run_crs() { update_crs_cfg - python3.12 -u ./main.py $DEFAULT_CFG 2>&1 | tee ./crs-java.log + mkdir -p $CRS_WORKDIR/worker-0 + python3.12 -u ./main.py $DEFAULT_CFG 2>&1 | tee 
$CRS_WORKDIR/worker-0/crs-java.log popd > /dev/null } diff --git a/crs/sink-targets.txt b/crs/sink-targets.txt deleted file mode 100644 index 9b2dd339e..000000000 --- a/crs/sink-targets.txt +++ /dev/null @@ -1,93 +0,0 @@ -######## -# Generated by crs/assets/sink-callee-apis/gen.py -# API-based sinkpoints (format: api#calleeClassName#methodName#methodDesc#markDesc) -# Coordinate-based sinkpoints (format: caller#className#methodName#methodDesc#fileName#lineNumber#bytecodeOffset#markDesc) -######## -api#java/lang/Class#forName#(Ljava/lang/String;)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/Class#forName#(Ljava/lang/String;ZLjava/lang/ClassLoader;)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/ClassLoader#loadClass#(Ljava/lang/String;)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/ClassLoader#loadClass#(Ljava/lang/String;Z)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/Class#forName#(Ljava/lang/Module;Ljava/lang/String;)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/ClassLoader#loadClass#(Ljava/lang/Module;Ljava/lang/String;)Ljava/lang/Class;#sink-UnsafeReflectiveCall -api#java/lang/Runtime#load##sink-LoadArbitraryLibrary -api#java/lang/Runtime#loadLibrary##sink-LoadArbitraryLibrary -api#java/lang/System#load##sink-LoadArbitraryLibrary -api#java/lang/System#loadLibrary##sink-LoadArbitraryLibrary -api#java/lang/System#mapLibraryName##sink-LoadArbitraryLibrary -api#java/lang/ClassLoader#findLibrary##sink-LoadArbitraryLibrary -api#java/lang/ProcessImpl#start##sink-OsCommandInjection -api#java/lang/ProcessBuilder#start##sink-OsCommandInjection -api#javax/naming/Context#lookup#(Ljava/lang/String;)Ljava/lang/Object;#sink-RemoteJNDILookup -api#javax/naming/Context#lookupLink#(Ljava/lang/String;)Ljava/lang/Object;#sink-RemoteJNDILookup -api#javax/naming/directory/DirContext#search#(Ljava/lang/String;Ljavax/naming/directory/Attributes;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/DirContext#search#(Ljava/lang/String;Ljavax/naming/directory/Attributes;[Ljava/lang/String;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/DirContext#search#(Ljava/lang/String;Ljava/lang/String;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/DirContext#search#(Ljavax/naming/Name;Ljava/lang/String;[Ljava/lang/Object;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/DirContext#search#(Ljava/lang/String;Ljava/lang/String;[Ljava/lang/Object;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/InitialDirContext#search#(Ljava/lang/String;Ljavax/naming/directory/Attributes;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/InitialDirContext#search#(Ljava/lang/String;Ljavax/naming/directory/Attributes;[Ljava/lang/String;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/InitialDirContext#search#(Ljava/lang/String;Ljava/lang/String;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection -api#javax/naming/directory/InitialDirContext#search#(Ljavax/naming/Name;Ljava/lang/String;[Ljava/lang/Object;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection 
-api#javax/naming/directory/InitialDirContext#search#(Ljava/lang/String;Ljava/lang/String;[Ljava/lang/Object;Ljavax/naming/directory/SearchControls;)Ljavax/naming/NamingEnumeration;#sink-LdapInjection
-api#javax/el/ExpressionFactory#createValueExpression##sink-ExpressionLanguageInjection
-api#javax/el/ExpressionFactory#createMethodExpression##sink-ExpressionLanguageInjection
-api#jakarta/el/ExpressionFactory#createValueExpression##sink-ExpressionLanguageInjection
-api#jakarta/el/ExpressionFactory#createMethodExpression##sink-ExpressionLanguageInjection
-api#javax/validation/ConstraintValidatorContext#buildConstraintViolationWithTemplate##sink-ExpressionLanguageInjection
-api#java/io/ObjectInputStream#<init>#(Ljava/io/InputStream;)V#sink-UnsafeDeserialization
-api#java/io/ObjectInputStream#readObject##sink-UnsafeDeserialization
-api#java/io/ObjectInputStream#readObjectOverride##sink-UnsafeDeserialization
-api#java/io/ObjectInputStream#readUnshared##sink-UnsafeDeserialization
-api#javax/xml/xpath/XPath#compile##sink-XPathInjection
-api#javax/xml/xpath/XPath#evaluate##sink-XPathInjection
-api#javax/xml/xpath/XPath#evaluateExpression##sink-XPathInjection
-api#java/util/regex/Pattern#compile#(Ljava/lang/String;I)Ljava/util/regex/Pattern;#sink-RegexInjection
-api#java/util/regex/Pattern#compile#(Ljava/lang/String;)Ljava/util/regex/Pattern;#sink-RegexInjection
-api#java/util/regex/Pattern#matches#(Ljava/lang/String;Ljava/lang/CharSequence;)Z#sink-RegexInjection
-api#java/lang/String#matches#(Ljava/lang/String;)Z#sink-RegexInjection
-api#java/lang/String#replaceAll#(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;#sink-RegexInjection
-api#java/lang/String#replaceFirst#(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;#sink-RegexInjection
-api#java/lang/String#split#(Ljava/lang/String;)[Ljava/lang/String;#sink-RegexInjection
-api#java/lang/String#split#(Ljava/lang/String;I)[Ljava/lang/String;#sink-RegexInjection
-api#java/sql/Statement#execute##sink-SqlInjection
-api#java/sql/Statement#executeBatch##sink-SqlInjection
-api#java/sql/Statement#executeLargeBatch##sink-SqlInjection
-api#java/sql/Statement#executeLargeUpdate##sink-SqlInjection
-api#java/sql/Statement#executeQuery##sink-SqlInjection
-api#java/sql/Statement#executeUpdate##sink-SqlInjection
-api#javax/persistence/EntityManager#createNativeQuery##sink-SqlInjection
-api#java/net/SocketImpl#connect##sink-ServerSideRequestForgery
-api#java/net/Socket#connect##sink-ServerSideRequestForgery
-api#java/net/SocksSocketImpl#connect##sink-ServerSideRequestForgery
-api#java/nio/channels/SocketChannel#connect##sink-ServerSideRequestForgery
-api#sun/nio/ch/SocketAdaptor#connect##sink-ServerSideRequestForgery
-api#jdk/internal/net/http/PlainHttpConnection#connect##sink-ServerSideRequestForgery
-api#java/nio/file/Files#createDirectory##sink-FilePathTraversal
-api#java/nio/file/Files#createDirectories##sink-FilePathTraversal
-api#java/nio/file/Files#createFile##sink-FilePathTraversal
-api#java/nio/file/Files#createTempDirectory##sink-FilePathTraversal
-api#java/nio/file/Files#createTempFile##sink-FilePathTraversal
-api#java/nio/file/Files#delete##sink-FilePathTraversal
-api#java/nio/file/Files#deleteIfExists##sink-FilePathTraversal
-api#java/nio/file/Files#lines##sink-FilePathTraversal
-api#java/nio/file/Files#newByteChannel##sink-FilePathTraversal
-api#java/nio/file/Files#newBufferedReader##sink-FilePathTraversal
-api#java/nio/file/Files#newBufferedWriter##sink-FilePathTraversal
-api#java/nio/file/Files#readString##sink-FilePathTraversal
-api#java/nio/file/Files#readAllBytes##sink-FilePathTraversal
-api#java/nio/file/Files#readAllLines##sink-FilePathTraversal
-api#java/nio/file/Files#readSymbolicLink##sink-FilePathTraversal
-api#java/nio/file/Files#write##sink-FilePathTraversal
-api#java/nio/file/Files#writeString##sink-FilePathTraversal
-api#java/nio/file/Files#newInputStream##sink-FilePathTraversal
-api#java/nio/file/Files#newOutputStream##sink-FilePathTraversal
-api#java/nio/file/probeContentType#open##sink-FilePathTraversal
-api#java/nio/channels/FileChannel#open##sink-FilePathTraversal
-api#java/nio/file/Files#copy##sink-FilePathTraversal
-api#java/nio/file/Files#mismatch##sink-FilePathTraversal
-api#java/nio/file/Files#move##sink-FilePathTraversal
-api#java/io/FileReader#<init>##sink-FilePathTraversal
-api#java/io/FileWriter#<init>##sink-FilePathTraversal
-api#java/io/FileInputStream#<init>##sink-FilePathTraversal
-api#java/io/FileOutputStream#<init>##sink-FilePathTraversal
-api#java/util/Scanner#<init>##sink-FilePathTraversal
diff --git a/crs/ssmode-lpg.toml b/crs/ssmode-lpg.toml
deleted file mode 100644
index 610b45d57..000000000
--- a/crs/ssmode-lpg.toml
+++ /dev/null
@@ -1,323 +0,0 @@
-[mock-java.cpv_0]
-path = "repo/src/main/java/com/aixcc/mock_java/App.java"
-line = 17
-class_name = "com.aixcc.mock_java.App"
-harness = ["OssFuzz1"]
-
-[batik.cpv_0]
-path = "repo/src/main/java/org/apache/batik/util/ParsedURLData.java"
-line = 554
-class_name = "org.apache.batik.util.ParsedURLData"
-harness = ["BatikOne", "BatikOneFDP"]
-
-[bcel.cpv_0]
-path = "repo/src/main/java/org/apache/bcel/verifier/statics/Pass2Verifier.java"
-line = 973
-class_name = "org.apache.bcel.verifier.statics.Pass2Verifier$CPESSC_Visitor"
-harness = ["BCELOne", "BCELOneFDP"]
-
-[beanutils.cpv_0]
-path = "repo/src/main/java/com/aixcc/beanutils/harnesses/one/BeanUtilsOne.java"
-line = 79
-class_name = "com.aixcc.beanutils.harnesses.one.BeanUtilsOne"
-harness = "BeanUtilsOne"
-
-[cron-utils.cpv_0]
-path = "repo/src/main/java/com/aixcc/cronutils/harnesses/one/CronUtilsOne.java"
-line = 35
-class_name = "com.aixcc.cronutils.harnesses.one.CronUtilsOne"
-harness = ["CronUtilsOne", "CronUtilsOneFDP"]
-
-[cxf.cpv_0]
-path = "repo/src/main/java/com/aixcc/cxf/harnesses/one/CXFOne.java"
-line = 35
-class_name = "com.aixcc.cxf.harnesses.one.CXFOne"
-harness = "CXFOne"
-
-[cxf.cpv_1]
-path = "repo/src/main/java/com/aixcc/cxf/harnesses/two/CXFTwo.java"
-line = 36
-class_name = "com.aixcc.cxf.harnesses.two.CXFTwo"
-harness = "CXFTwo"
-
-[cxf.cpv_2]
-path = "repo/src/main/java/org/apache/cxf/jaxrs/utils/ResourceUtils.java"
-line = 577
-class_name = "org.apache.cxf.jaxrs.utils.ResourceUtils"
-harness = ["CXFThree", "CXFThreeFDP"]
-
-[geonetwork.cpv_0]
-path = "repo/src/main/java/org/fao/geonet/kernel/harvest/harvester/localfilesystem/LocalFilesystemHarvester.java"
-line = 242
-class_name = "org.fao.geonet.kernel.harvest.harvester.localfilesystem.LocalFilesystemHarvester"
-harness = ["GeonetworkOne", "GeonetworkOneFDP"]
-
-[htmlunit.cpv_0]
-path = "repo/src/main/java/org/htmlunit/javascript/host/xml/XSLTProcessor.java"
-line = 262
-class_name = "org.htmlunit.javascript.host.xml.XSLTProcessor"
-harness = ["HtmlunitOne", "HtmlunitOneFDP"]
-
-[imaging.cpv_0]
-path = "repo/src/main/java/org/apache/commons/imaging/formats/png/chunks/PngChunkZtxt.java"
-line = 62
-class_name = "org.apache.commons.imaging.formats.png.chunks.PngChunkZtxt"
-harness = ["ImagingOne", "ImagingOneFDP"]
-
-[imaging.cpv_2]
-path = "repo/src/main/java/org/apache/commons/imaging/formats/jpeg/segments/PatSegment.java"
-line = 93
-class_name = "org.apache.commons.imaging.formats.jpeg.segments.PatSegment" -harness = ["ImagingTwo", "ImagingTwoFDP"] - -[jakarta-mail-api.cpv_0] -path = "repo/src/main/java/org/eclipse/angus/mail/handlers/text_xml.java" -line = 95 -class_name = "org.eclipse.angus.mail.handlers.text_xml" -harness = ["MailApiHarnessOne", "MailApiHarnessOneFDP"] - -[jenkins.cpv_0] -path = "repo/src/main/java/io/jenkins/plugins/UtilPlug/UtilMain.java" -line = 194 -class_name = "io.jenkins.plugins.UtilPlug.UtilMain" -harness = ["JenkinsTwo", "JenkinsTwoFDP"] - -[jenkins.cpv_10] -path = "repo/src/main/java/io/jenkins/plugins/toyplugin/SecretMessage.java" -line = 98 -class_name = "io.jenkins.plugins.toyplugin.SecretMessage" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_11] -path = "repo/src/main/java/io/jenkins/plugins/toyplugin/AccessFilter.java" -line = 87 -class_name = "io.jenkins.plugins.toyplugin.AccessFilter" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_12] -path = "repo/src/main/java/com/sonyericsson/jenkins/plugins/bfa/CauseManagement.java" -line = 245 -class_name = "com.sonyericsson.jenkins.plugins.bfa.CauseManagement" -harness = ["JenkinsFour", "JenkinsFourFDP"] - -[jenkins.cpv_13] -path = "repo/src/main/java/com/sonyericsson/jenkins/plugins/bfa/db/LocalFileKnowledgeBase.java" -line = 130 -class_name = "com.sonyericsson.jenkins.plugins.bfa.db.LocalFileKnowledgeBase" -harness = ["JenkinsFour", "JenkinsFourFDP"] - -[jenkins.cpv_14] -path = "repo/src/main/java/hudson/plugins/emailext/EmailConfig.java" -line = 102 -class_name = "hudson.plugins.emailext.EmailConfig" -harness = ["JenkinsFive", "JenkinsFiveFDP"] - -[jenkins.cpv_2] -path = "repo/src/main/java/hudson/PluginManager.java" -line = 1640 -class_name = "hudson.PluginManager" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_3] -path = "repo/src/main/java/hudson/ProxyConfiguration.java" -line = 593 -class_name = "hudson.ProxyConfiguration$DescriptorImpl2" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_4] -path = "repo/src/main/java/io/jenkins/plugins/coverage/CoverageProcessor.java" -line = 733 -class_name = "io.jenkins.plugins.coverage.CoverageProcessor" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_6] -path = "repo/src/main/java/io/jenkins/plugins/toyplugin/StateMonitor.java" -line = 58 -class_name = "io.jenkins.plugins.toyplugin.StateMonitor" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_7] -path = "repo/src/main/java/hudson/Proc.java" -line = 252 -class_name = "hudson.Proc$LocalProc" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[jenkins.cpv_8] -path = "repo/src/main/java/io/jenkins/plugins/toyplugin/AuthAction.java" -line = 229 -class_name = "io.jenkins.plugins.toyplugin.AuthAction" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[oripa.cpv_0] -path = "repo/src/main/java/oripa/persistence/doc/loader/LoaderXML.java" -line = 44 -class_name = "oripa.persistence.doc.loader.LoaderXML" -harness = ["OripaOne", "OripaOneFDP"] - -[r2-apache-commons-compress.cpv_0] -path = "repo/src/main/java/org/apache/commons/compress/archivers/zip/ZipFile.java" -line = 1552 -class_name = "org.apache.commons.compress.archivers.zip.ZipFile" -harness = "CompressZipFuzzer" - -[r2-apache-commons-compress.cpv_2] -path = "repo/src/main/java/org/apache/commons/compress/archivers/tar/TarArchiveEntry.java" -line = 1559 -class_name = "org.apache.commons.compress.archivers.tar.TarArchiveEntry" -harness = "CompressTarFuzzer" - -[r2-zookeeper.cpv_0] -path = 
"repo/src/main/java/org/apache/zookeeper/server/DataTree.java" -line = 555 -class_name = "org.apache.zookeeper.server.DataTree" -harness = "MultiProcessTxnFuzzer" - -[r3-tika-delta-01.cpv_0] -path = "repo/src/main/java/org/apache/tika/parser/external/ExternalParser.java" -line = 314 -class_name = "org.apache.tika.parser.external.ExternalParser" -harness = "HtmlParserFuzzer" - -[r3-tika-delta-03.cpv_0] -path = "repo/src/main/java/org/apache/tika/parser/threedxml/ThreeDXMLParser.java" -line = 141 -class_name = "org.apache.tika.parser.threedxml.ThreeDXMLParser" -harness = "ThreeDXMLParserFuzzer" - -[r3-tika-delta-04.cpv_0] -path = "repo/src/main/java/org/apache/tika/pipes/emitter/fs/FileSystemEmitter.java" -line = 155 -class_name = "org.apache.tika.pipes.emitter.fs.FileSystemEmitter" -harness = "TikaAppRUnpackerFuzzer" - -[r3-tika-delta-05.cpv_0] -path = "repo/src/main/java/org/apache/tika/cli/TikaCLI.java" -line = 1111 -class_name = "org.apache.tika.cli.TikaCLI$FileEmbeddedDocumentExtractor" -harness = "TikaAppUnpackerFuzzer" - -[r3-tika.cpv_0] -path = "repo/src/main/java/org/apache/tika/parser/xliff/XLIFF12Parser.java" -line = 88 -class_name = "org.apache.tika.parser.xliff.XLIFF12Parser" -harness = "XliffParserFuzzer" - -[r3-tika.cpv_1] -path = "repo/src/main/java/org/apache/tika/parser/microsoft/rtf/RTFEmbObjHandler.java" -line = 217 -class_name = "org.apache.tika.parser.microsoft.rtf.RTFEmbObjHandler" -harness = "RTFParserFuzzer" - -[r3-tika.cpv_2] -path = "repo/src/main/java/org/apache/tika/cli/TikaUntar.java" -line = 47 -class_name = "org.apache.tika.cli.TikaUntar" -harness = "TikaAppUntarringFuzzer" - -[r3-tika.cpv_4] -path = "repo/src/main/java/org/apache/tika/utils/ProcessUtils.java" -line = 94 -class_name = "org.apache.tika.utils.ProcessUtils" -harness = "TikaAppUnpackerFuzzer" - -[r3_5-jackson-databind.cpv_0] -path = "repo/src/main/java/com/fasterxml/jackson/databind/ser/BeanPropertyWriter.java" -line = 688 -class_name = "com.fasterxml.jackson.databind.ser.BeanPropertyWriter" -harness = "AdaLObjectReader3Fuzzer" - -[rdf4j.cpv_0] -path = "repo/src/main/java/org/eclipse/rdf4j/common/xml/SimpleSAXParser.java" -line = 221 -class_name = "org.eclipse.rdf4j.common.xml.SimpleSAXParser" -harness = ["Rdf4jOne", "Rdf4jOneFDP"] - -[rdf4j.cpv_1] -path = "repo/src/main/java/org/eclipse/rdf4j/rio/rdfxml/RDFXMLParser.java" -line = 328 -class_name = "org.eclipse.rdf4j.rio.rdfxml.RDFXMLParser" -harness = ["Rdf4jOne", "Rdf4jOneFDP"] - -[shiro.cpv_0] -path = "repo/src/main/java/org/apache/shiro/io/DefaultSerializer.java" -line = 77 -class_name = "org.apache.shiro.io.DefaultSerializer" -harness = "ShiroOne" - -[tika.cpv_1] -path = "repo/src/main/java/org/apache/tika/parser/external/ExternalParser.java" -line = 161 -class_name = "org.apache.tika.parser.external.ExternalParser" -harness = ["TikaTwo", "TikaTwoFDP"] - -[widoco.cpv_0] -path = "repo/src/main/java/widoco/WidocoUtils.java" -line = 269 -class_name = "widoco.WidocoUtils" -harness = ["WidocoOne", "WidocoOneFDP"] - -[xstream.cpv_2] -path = "repo/src/main/java/com/thoughtworks/xstream/io/xml/DomDriver.java" -line = 108 -class_name = "com.thoughtworks.xstream.io.xml.DomDriver" -harness = "XmlFuzzer" - -[xstream.cpv_4] -path = "repo/src/main/java/com/thoughtworks/xstream/converters/extended/StringInterpolationConverter.java" -line = 195 -class_name = "com.thoughtworks.xstream.converters.extended.StringInterpolationConverter" -harness = "XmlFuzzer" - -[ztzip.cpv_0] -path = "repo/src/main/java/org/zeroturnaround/zip/commons/FileUtils.java" -line = 54 
-class_name = "org.zeroturnaround.zip.commons.FileUtils" -harness = ["ZTZIPOne", "ZTZIPOneFDP"] - -[fuzzy.cpv_0] -path = "repo/src/main/java/me/xdrop/fuzzywuzzy/Config.java" -line = 84 -class_name = "me.xdrop.fuzzywuzzy.Config" -harness = ["FuzzyOne", "FuzzyOneFDP"] - -[jenkins.cpv_5] -path = "repo/src/main/java/io/jenkins/plugins/toyplugin/UserNameAction.java" -line = 74 -class_name = "io.jenkins.plugins.toyplugin.UserNameAction" -harness = ["JenkinsThree", "JenkinsThreeFDP"] - -[pac4j.cpv_0] -path = "repo/src/main/java/org/pac4j/core/util/JavaSerializationHelper.java" -line = 82 -class_name = "org.pac4j.core.util.JavaSerializationHelper" -harness = ["Pac4jOne", "Pac4jOneFDP"] - -[r1-zookeeper.cpv_3] -path = "repo/src/main/java/org/apache/zookeeper/server/DataTree.java" -line = 1077 -class_name = "org.apache.zookeeper.server.DataTree" -harness = "MultiProcessTxnFuzzer" - -[r1-zookeeper.cpv_4] -path = "repo/src/main/java/org/apache/zookeeper/server/DataTree.java" -line = 2182 -class_name = "org.apache.zookeeper.server.DataTree" -harness = "MultiProcessTxnFuzzer" - -[r2-apache-commons-compress.cpv_3] -path = "repo/src/main/java/org/apache/commons/compress/archivers/examples/Expander.java" -line = 86 -class_name = "org.apache.commons.compress.archivers.examples.Expander" -harness = "ExpanderFuzzer" - -[rdf4j.cpv_2] -path = "repo/src/main/java/org/eclipse/rdf4j/common/io/IOUtil.java" -line = 387 -class_name = "org.eclipse.rdf4j.common.io.IOUtil" -harness = ["Rdf4jOne", "Rdf4jOneFDP"] - -[tika.cpv_0] -path = "repo/src/main/java/org/apache/tika/cli/TikaCLI.java" -line = 1103 -class_name = "org.apache.tika.cli.TikaCLI$FileEmbeddedDocumentExtractor" -harness = ["TikaOne", "TikaOneFDP"] diff --git a/crs/ssmode-sink.txt b/crs/ssmode-sink.txt deleted file mode 100644 index 0252a1ad2..000000000 --- a/crs/ssmode-sink.txt +++ /dev/null @@ -1,5 +0,0 @@ -######## -# Generated by crs/assets/sink-callee-apis/gen.py -# API-based sinkpoints (format: api#calleeClassName#methodName#methodDesc#markDesc) -# Coordinate-based sinkpoints (format: caller#className#methodName#methodDesc#fileName#lineNumber#bytecodeOffset#markDesc) -######## diff --git a/crs/static-analysis/src/main/java/org/gts3/atlantis/staticanalysis/ArgumentParser.java b/crs/static-analysis/src/main/java/org/gts3/atlantis/staticanalysis/ArgumentParser.java index e16eee7d6..d2fa5c09d 100644 --- a/crs/static-analysis/src/main/java/org/gts3/atlantis/staticanalysis/ArgumentParser.java +++ b/crs/static-analysis/src/main/java/org/gts3/atlantis/staticanalysis/ArgumentParser.java @@ -155,8 +155,8 @@ public ArgumentParser(String[] args) throws IOException { configFileOption.setRequired(true); options.addOption(configFileOption); - Option targetFileOption = new Option("t", "target-file", true, "Path to file with target specifications (api and coordinate format)"); - targetFileOption.setRequired(true); + Option targetFileOption = new Option("t", "target-file", true, "Path to file with target specifications (api and coordinate format). Optional; if omitted, targets come solely from --sarif-sinkpoints."); + targetFileOption.setRequired(false); options.addOption(targetFileOption); Option sarifSinkpointsOption = Option.builder() @@ -239,7 +239,9 @@ public ArgumentParser(String[] args) throws IOException { this.configFile = Path.of(cmd.getOptionValue("config")); // Target specification inputs - this.targetFile = Path.of(cmd.getOptionValue("target-file")); + this.targetFile = cmd.hasOption("target-file") + ? 
diff --git a/oss-crs/crs.yaml b/oss-crs/crs.yaml
index be7d0681b..4e291ec44 100644
--- a/oss-crs/crs.yaml
+++ b/oss-crs/crs.yaml
@@ -11,6 +11,7 @@ target_build_phase:
     - build
     - crs/proj
     - crs/src
+    - crs/codeql-db

 crs_run_phase:
   runner:
diff --git a/oss-crs/dockerfiles/builder.Dockerfile b/oss-crs/dockerfiles/builder.Dockerfile
index 1d0370db8..5addd7e39 100644
--- a/oss-crs/dockerfiles/builder.Dockerfile
+++ b/oss-crs/dockerfiles/builder.Dockerfile
@@ -4,6 +4,21 @@ FROM ${target_base_image}
 COPY --from=libcrs . /libCRS
 RUN /libCRS/install.sh

+# CodeQL CLI for database creation during build
+ENV CODEQL_VERSION=2.20.4
+ENV CODEQL_HOME=/opt/codeql
+RUN set -eux; \
+    apt-get update; \
+    apt-get install -y --no-install-recommends wget unzip ca-certificates; \
+    rm -rf /var/lib/apt/lists/*; \
+    cd /tmp; \
+    wget -q "https://github.com/github/codeql-cli-binaries/releases/download/v${CODEQL_VERSION}/codeql-linux64.zip"; \
+    unzip -q codeql-linux64.zip; \
+    mv codeql "${CODEQL_HOME}"; \
+    rm codeql-linux64.zip; \
+    ln -s "${CODEQL_HOME}/codeql" /usr/local/bin/codeql; \
+    codeql pack download codeql/java-queries:codeql-java
+
 RUN mkdir -p /out/crs

 COPY ./build.py /crs/build.py
diff --git a/oss-crs/dockerfiles/runner.Dockerfile b/oss-crs/dockerfiles/runner.Dockerfile
index 65a1e0f21..00d1e001d 100644
--- a/oss-crs/dockerfiles/runner.Dockerfile
+++ b/oss-crs/dockerfiles/runner.Dockerfile
@@ -118,6 +118,19 @@ ENV ATL_JAZZER_DIR=/classpath/atl-jazzer
 ENV ATL_JAZZER_LIBAFL_DIR=/classpath/atl-libafl-jazzer
 ENV ATL_MOCK_JAZZER_DIR=/classpath/mock-jazzer

+## CRS-java atl-asm and atl-soot (must run before pip install coordinates)
+COPY ./crs/prebuilt ${JAVA_CRS_SRC}/prebuilt
+RUN cd ${JAVA_CRS_SRC}/prebuilt && \
+    ./mvn_install.sh
+ENV JACOCO_CLI_DIR=${JAVA_CRS_SRC}/prebuilt/jacococli
+
+## joern
+COPY --from=joern_builder /opt/joern ${JAVA_CRS_SRC}/joern
+ENV JOERN_DIR=${JAVA_CRS_SRC}/joern/Joern
+ENV JOERN_CLI=$JOERN_DIR/joern-cli
+ENV JAVA2CPG=$JOERN_DIR/joern-cli/frontends/javasrc2cpg/bin
+ENV PATH=$PATH:$JAVA_HOME/bin:$JOERN_CLI:$JAVA2CPG
+
 ## crs python package deps
 COPY ./crs/libs ${JAVA_CRS_SRC}/libs
 RUN cd ${JAVA_CRS_SRC}/libs/libFDP/libfdp && cargo update simd_cesu8 --precise 1.0.1
@@ -135,19 +148,6 @@ RUN /venv/bin/pip install --no-cache-dir \
     ${JAVA_CRS_SRC}/libs/claude-code-sdk-python && \
     rm -rf /root/.cache/pip

-## joern
-COPY --from=joern_builder /opt/joern ${JAVA_CRS_SRC}/joern
-ENV JOERN_DIR=${JAVA_CRS_SRC}/joern/Joern
-ENV JOERN_CLI=$JOERN_DIR/joern-cli
-ENV JAVA2CPG=$JOERN_DIR/joern-cli/frontends/javasrc2cpg/bin
-ENV PATH=$PATH:$JAVA_HOME/bin:$JOERN_CLI:$JAVA2CPG
-
-## CRS-java atl-asm and atl-soot
-COPY ./crs/prebuilt ${JAVA_CRS_SRC}/prebuilt
-RUN cd ${JAVA_CRS_SRC}/prebuilt && \
-    ./mvn_install.sh
-ENV JACOCO_CLI_DIR=${JAVA_CRS_SRC}/prebuilt/jacococli
-
 ## jazzer-llm-augmented
 COPY ./crs/jazzer-llm-augmented ${JAVA_CRS_SRC}/jazzer-llm-augmented
@@ -161,6 +161,12 @@ COPY ./crs/codeql ${JAVA_CRS_SRC}/codeql
 RUN cd ${JAVA_CRS_SRC}/codeql && \
     ./init.sh

+## filtering-agent (LLM-powered exploitability assessment)
+COPY ./crs/filtering-agent ${JAVA_CRS_SRC}/filtering-agent
+RUN cd ${JAVA_CRS_SRC}/filtering-agent && \
+    if [ -f requirements.txt ]; then /venv/bin/pip install --no-cache-dir -r requirements.txt; fi && \
+    rm -rf /root/.cache/pip
+
 ## llm-poc-gen
 COPY ./crs/llm-poc-gen ${JAVA_CRS_SRC}/llm-poc-gen
 ENV PATH=${PATH}:/root/.local/bin
@@ -217,13 +223,11 @@ RUN cd ${JAVA_CRS_SRC}/deepgen/jvm/stuck-point-analyzer && \
     rm -rf /root/.cache/pip

 ## crs-java main entry
-COPY ./crs/*.sh ./crs/*.py ./crs/requirements.txt ./crs/jazzer_driver_stub ./crs/crs-java.config ./crs/sink-targets.txt ${JAVA_CRS_SRC}/
+COPY ./crs/*.sh ./crs/*.py ./crs/requirements.txt ./crs/jazzer_driver_stub ./crs/crs-java.config ${JAVA_CRS_SRC}/
 COPY ./crs/javacrs_modules ${JAVA_CRS_SRC}/javacrs_modules
 COPY ./crs/tests ${JAVA_CRS_SRC}/tests
 RUN /venv/bin/pip install --no-cache-dir -r ${JAVA_CRS_SRC}/requirements.txt && \
     rm -rf /root/.cache/pip
-ENV JAVA_CRS_SINK_TARGET_CONF=${JAVA_CRS_SRC}/sink-targets.txt
-ENV JAVA_CRS_CUSTOM_SINK_YAML=${JAVA_CRS_SRC}/codeql/sink_definitions.yml

 ## git setup
 RUN git config --global --add safe.directory '*'
@@ -237,12 +241,6 @@ COPY --from=aixcc_afc_builder_base /usr/local/bin/jazzer_driver /classpath/raw-jazzer/
 COPY --from=aixcc_afc_builder_base /usr/local/bin/jazzer_junit.jar /classpath/raw-jazzer/
 COPY --from=aixcc_afc_builder_base /usr/local/lib/jazzer_api_deploy.jar /classpath/raw-jazzer/

-### NOTE: exp only
-COPY crs/ssmode-lpg.toml ${JAVA_CRS_SRC}/ssmode-lpg.toml
-COPY crs/ssmode-sink.txt ${JAVA_CRS_SRC}/ssmode-sink.txt
-RUN mkdir -p ${JAVA_CRS_SRC}/llm-poc-gen/eval/sheet && \
-    cp ${JAVA_CRS_SRC}/ssmode-lpg.toml ${JAVA_CRS_SRC}/llm-poc-gen/eval/sheet/cpv.toml
-
 #################################################################################
 ## oss-crs integration layer
 #################################################################################
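Because builder.Dockerfile pre-downloads the codeql/java-queries pack, the database shipped under crs/codeql-db can be analyzed in the runner without network access. A sketch of the standard `codeql database analyze` invocation from Python; the SARIF output path is an assumption for illustration, not something this diff configures:

import subprocess
from pathlib import Path

CODEQL_DB = Path("/out/crs/codeql-db")
SARIF_OUT = Path("/out/crs/codeql-results.sarif")  # hypothetical output location

# Standard CodeQL CLI usage; runs offline because the java-queries pack
# was fetched when the builder image was created.
subprocess.run(
    [
        "codeql", "database", "analyze", str(CODEQL_DB),
        "codeql/java-queries",
        "--format=sarif-latest",
        f"--output={SARIF_OUT}",
    ],
    check=True,
)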
diff --git a/oss-crs/scripts/oss_crs_entrypoint.sh b/oss-crs/scripts/oss_crs_entrypoint.sh
index 3a8c7cf04..76e83fdd7 100644
--- a/oss-crs/scripts/oss_crs_entrypoint.sh
+++ b/oss-crs/scripts/oss_crs_entrypoint.sh
@@ -8,6 +8,7 @@ set -e
 libCRS download-build-output build /out
 libCRS download-build-output crs/proj /out/crs/proj
 libCRS download-build-output crs/src /out/crs/src
+libCRS download-build-output crs/codeql-db /out/crs/codeql-db || echo "WARNING: No CodeQL database in build outputs"

 #############################################
 # 2. Register output directories
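The entrypoint deliberately tolerates a missing crs/codeql-db artifact (the `|| echo "WARNING: ..."` above survives `set -e`), so downstream modules should probe for a usable database rather than assume one exists. A heuristic check along these lines, assuming a finalized CodeQL database carries codeql-database.yml at its root; the codeql_db_ready name is illustrative:

from pathlib import Path

CODEQL_DB = Path("/out/crs/codeql-db")

def codeql_db_ready(db_dir: Path = CODEQL_DB) -> bool:
    # An absent or empty directory means build-phase creation failed;
    # CodeQL-based modules should be skipped instead of crashing.
    return db_dir.is_dir() and (db_dir / "codeql-database.yml").is_file()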