diff --git a/src/Disks/DiskObjectStorage/ObjectStorages/S3/diskSettings.cpp b/src/Disks/DiskObjectStorage/ObjectStorages/S3/diskSettings.cpp index 5cdfacd87378..e770e496a1f9 100644 --- a/src/Disks/DiskObjectStorage/ObjectStorages/S3/diskSettings.cpp +++ b/src/Disks/DiskObjectStorage/ObjectStorages/S3/diskSettings.cpp @@ -224,6 +224,9 @@ getClient(const S3::URI & url, const S3Settings & settings, ContextPtr context, LOG_DEBUG(getLogger("getClient"), "Got new access tokens {} {} {}", access_key_id, secret_access_key, session_token); } } + + auto shared_cache = S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(url.endpoint, url.bucket); + return S3::ClientFactory::instance().create( client_configuration, client_settings, @@ -233,7 +236,8 @@ getClient(const S3::URI & url, const S3Settings & settings, ContextPtr context, auth_settings.server_side_encryption_kms_config, auth_settings.getHeaders(), credentials_configuration, - session_token); + session_token, + shared_cache); } } diff --git a/src/IO/S3/Client.cpp b/src/IO/S3/Client.cpp index 5df6a86d4327..7adfcbf933dd 100644 --- a/src/IO/S3/Client.cpp +++ b/src/IO/S3/Client.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -219,11 +220,12 @@ std::unique_ptr Client::create( const std::shared_ptr & credentials_provider, const PocoHTTPClientConfiguration & client_configuration, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, - const ClientSettings & client_settings) + const ClientSettings & client_settings, + const std::shared_ptr & shared_cache) { verifyClientConfiguration(client_configuration); return std::unique_ptr( - new Client(max_redirects_, std::move(sse_kms_config_), credentials_provider, client_configuration, sign_payloads, client_settings)); + new Client(max_redirects_, std::move(sse_kms_config_), credentials_provider, client_configuration, sign_payloads, client_settings, shared_cache)); } std::unique_ptr Client::clone() const @@ -258,7 +260,8 @@ Client::Client( const std::shared_ptr & credentials_provider_, const PocoHTTPClientConfiguration & client_configuration_, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads_, - const ClientSettings & client_settings_) + const ClientSettings & client_settings_, + const std::shared_ptr & shared_cache) : Aws::S3::S3Client(credentials_provider_, client_configuration_, sign_payloads_, client_settings_.use_virtual_addressing) , credentials_provider(credentials_provider_) , client_configuration(client_configuration_) @@ -298,7 +301,10 @@ Client::Client( detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL; - cache = std::make_shared(); + if (shared_cache) + cache = shared_cache; + else + cache = std::make_shared(); ClientCacheRegistry::instance().registerClient(cache); ProfileEvents::increment(ProfileEvents::S3Clients); @@ -321,7 +327,7 @@ Client::Client( , sse_kms_config(other.sse_kms_config) , log(getLogger("S3Client")) { - cache = std::make_shared(*other.cache); + cache = other.cache; ClientCacheRegistry::instance().registerClient(cache); logConfiguration(); @@ -1108,37 +1114,77 @@ void ClientCache::clearCache() void ClientCacheRegistry::registerClient(const std::shared_ptr & client_cache) { std::lock_guard lock(clients_mutex); - auto [it, inserted] = client_caches.emplace(client_cache.get(), client_cache); - if (!inserted) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Same S3 client registered twice"); + auto it = client_caches.find(client_cache.get()); + if (it != client_caches.end()) + { + ++it->second.second; + return; + } + client_caches.emplace(client_cache.get(), std::pair{std::weak_ptr(client_cache), size_t(1)}); } void ClientCacheRegistry::unregisterClient(ClientCache * client) { std::lock_guard lock(clients_mutex); - auto erased = client_caches.erase(client); - if (erased == 0) + auto it = client_caches.find(client); + if (it == client_caches.end()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Can't unregister S3 client, either it was already unregistered or not registered at all"); + if (--it->second.second == 0) + client_caches.erase(it); } -void ClientCacheRegistry::clearCacheForAll() +void ClientCacheRegistry::pruneExpiredCachesLocked() { - std::lock_guard lock(clients_mutex); + std::erase_if(cache_by_endpoint_bucket, [](const auto & pair) { return pair.second.expired(); }); +} - for (auto it = client_caches.begin(); it != client_caches.end();) +std::shared_ptr ClientCacheRegistry::getOrCreateCacheForKey(const std::string & endpoint, const std::string & bucket) +{ + SipHash hash; + hash.update(endpoint.size()); + hash.update(endpoint); + hash.update(bucket); + UInt128 key = hash.get128(); + + std::lock_guard lock(cache_by_key_mutex); + if (auto it = cache_by_endpoint_bucket.find(key); it != cache_by_endpoint_bucket.end()) { - if (auto locked_client = it->second.lock(); locked_client) - { - locked_client->clearCache(); - ++it; - } - else + if (auto cached = it->second.lock(); cached) + return cached; + cache_by_endpoint_bucket.erase(it); + } + auto cache = std::make_shared(); + cache_by_endpoint_bucket[key] = cache; + + pruneExpiredCachesLocked(); + + return cache; +} + +void ClientCacheRegistry::clearCacheForAll() +{ + { + std::lock_guard lock(clients_mutex); + + for (auto it = client_caches.begin(); it != client_caches.end();) { - LOG_INFO(getLogger("ClientCacheRegistry"), "Deleting leftover S3 client cache"); - it = client_caches.erase(it); + if (auto locked_client = it->second.first.lock(); locked_client) + { + locked_client->clearCache(); + ++it; + } + else + { + LOG_INFO(getLogger("ClientCacheRegistry"), "Deleting leftover S3 client cache"); + it = client_caches.erase(it); + } } } + { + std::lock_guard lock(cache_by_key_mutex); + pruneExpiredCachesLocked(); + } } ClientFactory::ClientFactory() @@ -1183,7 +1229,8 @@ std::unique_ptr ClientFactory::create( // NOLINT ServerSideEncryptionKMSConfig sse_kms_config, HTTPHeaderEntries headers, CredentialsConfiguration credentials_configuration, - const String & session_token) + const String & session_token, + const std::shared_ptr & shared_cache) { PocoHTTPClientConfiguration client_configuration = cfg_; client_configuration.updateSchemeAndRegion(); @@ -1237,7 +1284,8 @@ std::unique_ptr ClientFactory::create( // NOLINT client_configuration, // Client configuration. client_settings.is_s3express_bucket ? Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::RequestDependent : Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, - client_settings); + client_settings, + shared_cache); } PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT diff --git a/src/IO/S3/Client.h b/src/IO/S3/Client.h index 4a65239a8582..dcde0e180978 100644 --- a/src/IO/S3/Client.h +++ b/src/IO/S3/Client.h @@ -32,6 +32,7 @@ struct ServerSideEncryptionKMSConfig #include #include #include +#include #include #include @@ -77,11 +78,17 @@ class ClientCacheRegistry void registerClient(const std::shared_ptr & client_cache); void unregisterClient(ClientCache * client); void clearCacheForAll(); + std::shared_ptr getOrCreateCacheForKey(const std::string & endpoint, const std::string & bucket); + private: ClientCacheRegistry() = default; + void pruneExpiredCachesLocked() TSA_REQUIRES(cache_by_key_mutex); + std::mutex clients_mutex; - std::unordered_map> client_caches TSA_GUARDED_BY(clients_mutex); + std::unordered_map, size_t>> client_caches TSA_GUARDED_BY(clients_mutex); + std::mutex cache_by_key_mutex; + std::unordered_map, UInt128Hash> cache_by_endpoint_bucket TSA_GUARDED_BY(cache_by_key_mutex); }; bool isS3ExpressEndpoint(const std::string & endpoint); @@ -128,7 +135,8 @@ class Client : private Aws::S3::S3Client const std::shared_ptr & credentials_provider, const PocoHTTPClientConfiguration & client_configuration, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, - const ClientSettings & client_settings); + const ClientSettings & client_settings, + const std::shared_ptr & shared_cache = nullptr); std::unique_ptr clone() const; @@ -240,6 +248,9 @@ class Client : private Aws::S3::S3Client const PocoHTTPClientConfiguration & getClientConfiguration() const { return client_configuration; } + /// For testing purposes only + ClientCache * getRawCache() const { return cache.get(); } + protected: // visible for testing Client(size_t max_redirects_, @@ -247,7 +258,8 @@ class Client : private Aws::S3::S3Client const std::shared_ptr & credentials_provider_, const PocoHTTPClientConfiguration & client_configuration, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy sign_payloads, - const ClientSettings & client_settings_); + const ClientSettings & client_settings_, + const std::shared_ptr & shared_cache = nullptr); private: Client( @@ -346,7 +358,8 @@ class ClientFactory ServerSideEncryptionKMSConfig sse_kms_config, HTTPHeaderEntries headers, CredentialsConfiguration credentials_configuration, - const String & session_token = ""); + const String & session_token = "", + const std::shared_ptr & shared_cache = nullptr); PocoHTTPClientConfiguration createClientConfiguration( const String & force_region, diff --git a/src/IO/S3/tests/gtest_aws_s3_client.cpp b/src/IO/S3/tests/gtest_aws_s3_client.cpp index a2b662ff83d5..17456ef0e8cc 100644 --- a/src/IO/S3/tests/gtest_aws_s3_client.cpp +++ b/src/IO/S3/tests/gtest_aws_s3_client.cpp @@ -539,4 +539,128 @@ TEST(IOTestAwsS3Client, AssumeRole) } } +TEST(IOTestAwsS3Client, ClientCacheRegistryGetOrCreateCacheForKey) +{ + auto & registry = DB::S3::ClientCacheRegistry::instance(); + + std::shared_ptr cache_ab1 = registry.getOrCreateCacheForKey("endpoint1", "bucket1"); + std::shared_ptr cache_ab2 = registry.getOrCreateCacheForKey("endpoint1", "bucket1"); + EXPECT_EQ(cache_ab1.get(), cache_ab2.get()) << "Same (endpoint, bucket) should return the same cache"; + + std::shared_ptr cache_b1 = registry.getOrCreateCacheForKey("endpoint1", "bucket2"); + EXPECT_NE(cache_ab1.get(), cache_b1.get()) << "Different bucket should return different cache"; + + std::shared_ptr cache_e2 = registry.getOrCreateCacheForKey("endpoint2", "bucket1"); + EXPECT_NE(cache_ab1.get(), cache_e2.get()) << "Different endpoint should return different cache"; + + auto cache_concat1 = registry.getOrCreateCacheForKey("ab", "c"); + auto cache_concat2 = registry.getOrCreateCacheForKey("a", "bc"); + EXPECT_NE(cache_concat1.get(), cache_concat2.get()) + << "Pairs with identical concatenation but different boundary must not share a cache"; +} + +TEST(IOTestAwsS3Client, ClientSharesCacheWithClone) +{ + DB::RemoteHostFilter remote_host_filter; + DB::S3::URI uri("https://s3.eu-central-1.amazonaws.com/my-bucket/key"); + DB::S3::PocoHTTPClientConfiguration client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + "eu-central-1", + remote_host_filter, + 10, + DB::S3::PocoHTTPClientConfiguration::RetryStrategy{.max_retries = 0}, + true, + true, + false, + false, + {}, + {}, + "https"); + client_configuration.endpointOverride = uri.endpoint; + + DB::S3::ClientSettings client_settings{ + .use_virtual_addressing = uri.is_virtual_hosted_style, + .disable_checksum = false, + .gcs_issue_compose_request = false, + .is_s3express_bucket = false, + }; + + auto shared_cache = DB::S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(uri.endpoint, uri.bucket); + std::unique_ptr client = DB::S3::ClientFactory::instance().create( + client_configuration, + client_settings, + "access", + "secret", + "", + {}, + {}, + DB::S3::CredentialsConfiguration{.use_environment_credentials = false, .use_insecure_imds_request = false}, + "", + shared_cache); + + ASSERT_TRUE(client); + std::unique_ptr clone = client->clone(); + ASSERT_TRUE(clone); + + EXPECT_EQ(client->getRawCache(), shared_cache.get()) << "Client should use the shared cache"; + EXPECT_EQ(clone->getRawCache(), client->getRawCache()) << "Clone should share the same cache as original"; +} + +TEST(IOTestAwsS3Client, TwoClientsWithSharedCacheUnregisterRefcount) +{ + DB::RemoteHostFilter remote_host_filter; + DB::S3::URI uri("https://s3.us-east-1.amazonaws.com/another-bucket/key"); + DB::S3::PocoHTTPClientConfiguration client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + "us-east-1", + remote_host_filter, + 10, + DB::S3::PocoHTTPClientConfiguration::RetryStrategy{.max_retries = 0}, + true, + true, + false, + false, + {}, + {}, + "https"); + client_configuration.endpointOverride = uri.endpoint; + + DB::S3::ClientSettings client_settings{ + .use_virtual_addressing = uri.is_virtual_hosted_style, + .disable_checksum = false, + .gcs_issue_compose_request = false, + .is_s3express_bucket = false, + }; + + auto shared_cache = DB::S3::ClientCacheRegistry::instance().getOrCreateCacheForKey(uri.endpoint, uri.bucket); + std::unique_ptr client1 = DB::S3::ClientFactory::instance().create( + client_configuration, + client_settings, + "ak", + "sk", + "", + {}, + {}, + DB::S3::CredentialsConfiguration{.use_environment_credentials = false, .use_insecure_imds_request = false}, + "", + shared_cache); + std::unique_ptr client2 = DB::S3::ClientFactory::instance().create( + client_configuration, + client_settings, + "ak", + "sk", + "", + {}, + {}, + DB::S3::CredentialsConfiguration{.use_environment_credentials = false, .use_insecure_imds_request = false}, + "", + shared_cache); + + ASSERT_TRUE(client1); + ASSERT_TRUE(client2); + EXPECT_EQ(client1->getRawCache(), client2->getRawCache()); + + client1.reset(); + client2.reset(); + // If refcount was wrong, unregisterClient would throw when the second client is destroyed +} + #endif