|
16 | 16 | #include <shared_mutex> |
17 | 17 | #include <functional> |
18 | 18 | #include <list> |
| 19 | +#include <optional> |
| 20 | +#include <algorithm> // For std::max |
19 | 21 |
|
20 | 22 | template <typename Key, typename Value, typename Hash = std::hash<Key>> |
21 | | -class ConcurrentHashMap { |
| 23 | +class ConcurrentHashMap |
| 24 | +{ |
22 | 25 | private: |
23 | 26 | // Define the structure of each node in the hash map |
24 | | - struct Node { |
| 27 | + struct Node |
| 28 | + { |
25 | 29 | Key key; |
26 | 30 | Value value; |
27 | 31 | }; |
28 | 32 |
|
29 | 33 | // Define the hash map buckets and associated mutexes |
30 | 34 | std::vector<std::list<Node>> buckets; |
31 | | - std::vector<std::shared_mutex> mutexes; |
| 35 | + mutable std::vector<std::shared_mutex> mutexes; |
32 | 36 | Hash hashFunction; |
33 | 37 |
|
34 | | - // Get the mutex for a given key |
35 | | - std::shared_mutex& getMutex(const Key& key) { |
| 38 | + // Get mutex for a key |
| 39 | + std::shared_mutex &getMutex(const Key &key) const |
| 40 | + { |
36 | 41 | std::size_t hashValue = hashFunction(key); |
37 | 42 | return mutexes[hashValue % mutexes.size()]; |
38 | 43 | } |
39 | 44 |
|
40 | 45 | public: |
41 | | - explicit ConcurrentHashMap(std::size_t num_buckets = 16) : buckets(num_buckets), mutexes(num_buckets) {} |
| 46 | + explicit ConcurrentHashMap(std::size_t num_buckets = 16) |
| 47 | + : buckets(std::max<std::size_t>(1, num_buckets)), |
| 48 | + mutexes(std::max<std::size_t>(1, num_buckets)) {} |
42 | 49 |
|
43 | 50 | // Insert a key-value pair into the hash map |
44 | | - void insert(const Key& key, const Value& value) { |
| 51 | + void insert(const Key &key, const Value &value) |
| 52 | + { |
45 | 53 | std::unique_lock lock(getMutex(key)); |
46 | | - std::size_t hashValue = hashFunction(key); |
47 | | - std::size_t index = hashValue % buckets.size(); |
48 | | - |
49 | | - auto& bucket = buckets[index]; |
50 | | - auto it = std::find_if(bucket.begin(), bucket.end(), [&](const Node& node) { return node.key == key; }); |
51 | | - |
52 | | - if (it != bucket.end()) { |
53 | | - // Update existing key |
| 54 | + std::size_t index = hashFunction(key) % buckets.size(); |
| 55 | + auto &bucket = buckets[index]; |
| 56 | + auto it = std::find_if(bucket.begin(), bucket.end(), |
| 57 | + [&](const Node & node) |
| 58 | + { |
| 59 | + return node.key == key; |
| 60 | + }); |
| 61 | + |
| 62 | + if(it != bucket.end()) |
| 63 | + { |
54 | 64 | it->value = value; |
55 | | - } else { |
56 | | - // Insert new key-value pair |
57 | | - bucket.push_back({key, value}); |
58 | 65 | } |
| 66 | + else bucket.push_back({key, value}); |
59 | 67 | } |
60 | 68 |
|
61 | 69 | // Retrieve the value associated with a key from the hash map |
62 | | - bool get(const Key& key, Value& value) { |
| 70 | + std::optional<Value> get(const Key &key) const |
| 71 | + { |
63 | 72 | std::shared_lock lock(getMutex(key)); |
64 | | - std::size_t hashValue = hashFunction(key); |
65 | | - std::size_t index = hashValue % buckets.size(); |
66 | | - |
67 | | - const auto& bucket = buckets[index]; |
68 | | - auto it = std::find_if(bucket.begin(), bucket.end(), [&](const Node& node) { return node.key == key; }); |
69 | | - |
70 | | - if (it != bucket.end()) { |
71 | | - value = it->value; |
72 | | - return true; // Found the key |
73 | | - } |
74 | | - |
75 | | - return false; // Key not found |
| 73 | + std::size_t index = hashFunction(key) % buckets.size(); |
| 74 | + const auto &bucket = buckets[index]; |
| 75 | + auto it = std::find_if(bucket.begin(), bucket.end(), |
| 76 | + [&](const Node & node) |
| 77 | + { |
| 78 | + return node.key == key; |
| 79 | + }); |
| 80 | + return (it != bucket.end()) ? std::optional<Value>(it->value) : std::nullopt; |
76 | 81 | } |
77 | 82 |
|
78 | 83 | // Remove a key-value pair from the hash map |
79 | | - void remove(const Key& key) { |
| 84 | + void remove(const Key &key) |
| 85 | + { |
80 | 86 | std::unique_lock lock(getMutex(key)); |
81 | | - std::size_t hashValue = hashFunction(key); |
82 | | - std::size_t index = hashValue % buckets.size(); |
83 | | - |
84 | | - auto& bucket = buckets[index]; |
85 | | - bucket.remove_if([&](const Node& node) { return node.key == key; }); |
| 87 | + std::size_t index = hashFunction(key) % buckets.size(); |
| 88 | + buckets[index].remove_if([&](const Node & node) |
| 89 | + { |
| 90 | + return node.key == key; |
| 91 | + }); |
86 | 92 | } |
87 | 93 |
|
88 | | - // Print the contents of the hash map |
89 | | - void print() { |
90 | | - for (std::size_t i = 0; i < buckets.size(); ++i) { |
| 94 | + // Print to any ostream |
| 95 | + void print(std::ostream &os = std::cout) const |
| 96 | + { |
| 97 | + for(std::size_t i = 0; i < buckets.size(); ++i) |
| 98 | + { |
91 | 99 | std::shared_lock lock(mutexes[i]); |
92 | | - std::cout << "Bucket " << i << ": "; |
93 | | - const auto& bucket = buckets[i]; |
94 | | - for (const auto& node : bucket) { |
95 | | - std::cout << "(" << node.key << ", " << node.value << ") "; |
| 100 | + os << "Bucket " << i << ": "; |
| 101 | + |
| 102 | + for(const auto &node : buckets[i]) |
| 103 | + { |
| 104 | + os << "(" << node.key << ", " << node.value << ") "; |
96 | 105 | } |
97 | | - std::cout << std::endl; |
| 106 | + |
| 107 | + os << "\n"; |
98 | 108 | } |
99 | 109 | } |
100 | 110 |
|
|
0 commit comments