Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk_v2/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ set(FOUNDRY_LOCAL_SOURCES
src/download/model_registry_client.cc
src/ep_detection/cuda_ep_bootstrapper.cc
src/ep_detection/ep_detector.cc
src/ep_detection/ep_utils.cc
src/ep_detection/runtime_version_info.cc
src/ep_detection/webgpu_ep_bootstrapper.cc
src/exception.cc
Expand Down
37 changes: 9 additions & 28 deletions sdk_v2/cpp/src/ep_detection/cuda_ep_bootstrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// Licensed under the MIT License.
#include "ep_detection/cuda_ep_bootstrapper.h"

#include "ep_detection/ep_utils.h"
#include "logger.h"
#include "util/file_lock.h"
#include "http/http_download.h"
#include "util/sha256.h"
#include "util/zip_extract.h"

#include <fmt/format.h>
Expand Down Expand Up @@ -61,31 +61,6 @@ constexpr ExpectedBinary kExpectedBinaries[] = {
constexpr const char* kRegistrationName = "Foundry.CUDA";
constexpr const char* kCudaProviderDll = "onnxruntime_providers_cuda.dll";

/// Verify all expected binaries exist and have correct SHA256 hashes.
bool VerifyPackage(const std::filesystem::path& dir, fl::ILogger& logger) {
for (const auto& expected : kExpectedBinaries) {
auto file_path = dir / expected.filename;

if (!std::filesystem::exists(file_path)) {
return false;
}

auto hash = fl::Sha256File(file_path);

// Case-insensitive comparison
std::string expected_hash(expected.sha256);
if (!std::equal(hash.begin(), hash.end(), expected_hash.begin(), expected_hash.end(),
[](char a, char b) { return std::toupper(a) == std::toupper(b); })) {
logger.Log(fl::LogLevel::Warning,
fmt::format("CUDA EP: hash mismatch for {}: got {}, expected {}",
expected.filename, hash, expected.sha256));
return false;
}
}

return true;
}

} // anonymous namespace

namespace fl {
Expand Down Expand Up @@ -127,7 +102,10 @@ bool CudaEpBootstrapper::DownloadAndRegister(bool force,
FileLock lock(lock_path);

// Check if package already exists and is valid
if (VerifyPackage(ep_dir, logger)) {
if (fl::VerifyEpPackage(ep_dir,
{{kExpectedBinaries[0].filename, kExpectedBinaries[0].sha256},
{kExpectedBinaries[1].filename, kExpectedBinaries[1].sha256}},
"CUDA EP", logger)) {
logger.Log(LogLevel::Information, "CUDA EP: package already valid, skipping download");
} else {
// Clean up any partial install
Expand Down Expand Up @@ -170,7 +148,10 @@ bool CudaEpBootstrapper::DownloadAndRegister(bool force,
std::filesystem::remove(zip_path);

// Verify
if (!VerifyPackage(ep_dir, logger)) {
if (!fl::VerifyEpPackage(ep_dir,
{{kExpectedBinaries[0].filename, kExpectedBinaries[0].sha256},
{kExpectedBinaries[1].filename, kExpectedBinaries[1].sha256}},
"CUDA EP", logger)) {
logger.Log(LogLevel::Warning, "CUDA EP: verification failed after download");
return false;
}
Expand Down
5 changes: 4 additions & 1 deletion sdk_v2/cpp/src/ep_detection/ep_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,10 @@ EpDownloadResult EpDetector::DownloadAndRegisterEps(const std::vector<std::strin

logger_.Log(LogLevel::Information, "Downloading and registering EP: " + bs->Name());

if (bs->DownloadAndRegister(/*force=*/true, wrapped_cb, logger_)) {
// Reuse previously downloaded EP packages unless the caller explicitly asks
// for a forced refresh. Downloading every time made the bootstrapper
// re-fetch and re-register EPs on every invocation.
if (bs->DownloadAndRegister(/*force=*/false, wrapped_cb, logger_)) {
result.registered_eps.push_back(bs->Name());

// Update cached registration state in place under the cache lock so
Expand Down
43 changes: 43 additions & 0 deletions sdk_v2/cpp/src/ep_detection/ep_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ep_detection/ep_utils.h"

#include "logger.h"
#include "util/sha256.h"

#include <fmt/format.h>

#include <algorithm>
#include <cctype>
#include <string>

namespace fl {

bool VerifyEpPackage(
const std::filesystem::path& dir,
std::initializer_list<std::pair<std::string_view, std::string_view>> expected,
std::string_view ep_name,
ILogger& logger) {
for (const auto& [filename, expected_hash] : expected) {
auto file_path = dir / filename;

if (!std::filesystem::exists(file_path)) {
return false;
}

auto hash = Sha256File(file_path);

// Case-insensitive hex comparison
if (!std::equal(hash.begin(), hash.end(), expected_hash.begin(), expected_hash.end(),
[](char a, char b) { return std::toupper(a) == std::toupper(b); })) {
logger.Log(LogLevel::Warning,
fmt::format("{}: hash mismatch for {}: got {}, expected {}",
ep_name, filename, hash, expected_hash));
return false;
}
}

return true;
}

} // namespace fl
27 changes: 27 additions & 0 deletions sdk_v2/cpp/src/ep_detection/ep_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <filesystem>
#include <initializer_list>
#include <string_view>
#include <utility>

namespace fl {

class ILogger;

/// Verify a set of binaries in @p dir all exist and match their expected SHA-256 hashes.
///
/// @param dir Directory containing the extracted EP binaries.
/// @param expected List of (filename, expected_sha256_hex) pairs.
/// @param ep_name EP name used in warning log messages (e.g. "CUDA EP").
/// @param logger Logger for diagnostic output.
/// @return true if every file exists and its hash matches; false otherwise.
bool VerifyEpPackage(
const std::filesystem::path& dir,
std::initializer_list<std::pair<std::string_view, std::string_view>> expected,
std::string_view ep_name,
ILogger& logger);

} // namespace fl
Loading