Skip to content

Commit 7d0a5ce

Browse files
committed
Add concept of a key cache to the SciTokens verifier.
With this, we can bypass the remote lookup for frequently-used issuers.
1 parent cf053ee commit 7d0a5ce

File tree

4 files changed

+306
-15
lines changed

4 files changed

+306
-15
lines changed

CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g")
2020

2121
include (FindPkgConfig)
2222
pkg_check_modules(LIBCRYPTO REQUIRED libcrypto)
23-
set (INCLUDE_DIRECTORIES ${INCLUDE_DIRECTORIES} ${LIBCRYPTO_INCLUDE_DIRS})
23+
pkg_check_modules(SQLITE REQUIRED sqlite3)
2424

25-
include_directories( "${PROJECT_SOURCE_DIR}" ${JWT_CPP_INCLUDES} ${CURL_INCLUDES})
25+
include_directories( "${PROJECT_SOURCE_DIR}" ${JWT_CPP_INCLUDES} ${CURL_INCLUDES} ${LIBCRYPTO_INCLUDE_DIRS} ${SQLITE_INCLUDE_DIRS})
2626

27-
add_library(SciTokens SHARED src/scitokens.cpp src/scitokens_internal.cpp)
28-
target_link_libraries(SciTokens ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES})
27+
add_library(SciTokens SHARED src/scitokens.cpp src/scitokens_internal.cpp src/scitokens_cache.cpp)
28+
target_link_libraries(SciTokens ${LIBCRYPTO_LIBRARIES} ${CURL_LIBRARIES} ${SQLITE_LIBRARIES})
2929
set_target_properties(SciTokens PROPERTIES LINK_FLAGS "-Wl,--version-script=${PROJECT_SOURCE_DIR}/configs/export-symbols")
3030

3131
add_executable(scitokens-test src/test.cpp)

src/scitokens_cache.cpp

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
2+
#include <cstdint>
3+
#include <string>
4+
#include <vector>
5+
6+
#include <pwd.h>
7+
#include <stdlib.h>
8+
#include <unistd.h>
9+
#include <sys/stat.h>
10+
11+
#ifndef PICOJSON_USE_INT64
12+
#define PICOJSON_USE_INT64
13+
#endif
14+
#include <jwt-cpp/picojson.h>
15+
#include <sqlite3.h>
16+
17+
#include "scitokens_internal.h"
18+
19+
namespace {
20+
21+
void
22+
initialize_cachedb(const std::string &keycache_file) {
23+
24+
sqlite3 *db;
25+
int rc = sqlite3_open(keycache_file.c_str(), &db);
26+
if (rc != SQLITE_OK) {
27+
std::cerr << "SQLite key cache creation failed." << std::endl;
28+
sqlite3_close(db);
29+
return;
30+
}
31+
char *err_msg = nullptr;
32+
rc = sqlite3_exec(db, "CREATE TABLE IF NOT EXISTS keycache ("
33+
"issuer text UNIQUE PRIMARY KEY NOT NULL,"
34+
"keys text NOT NULL)",
35+
NULL, 0, &err_msg);
36+
if (rc) {
37+
std::cerr << "Sqlite table creation failed: " << err_msg << std::endl;
38+
sqlite3_free(err_msg);
39+
}
40+
sqlite3_close(db);
41+
}
42+
43+
/**
44+
* Get the Cache file location
45+
*
46+
* 1. $XDG_CACHE_HOME
47+
* 2. .cache subdirectory of home directory as returned by the password database
48+
*/
49+
std::string
50+
get_cache_file() {
51+
52+
const char *xdg_cache_home = getenv("XDG_CACHE_HOME");
53+
54+
auto bufsize = sysconf(_SC_GETPW_R_SIZE_MAX);
55+
bufsize = (bufsize == -1) ? 16384 : bufsize;
56+
57+
std::vector<char> buf;
58+
buf.reserve(bufsize);
59+
60+
std::string home_dir;
61+
struct passwd pwd, *result = NULL;
62+
getpwuid_r(geteuid(), &pwd, &buf[0], bufsize, &result);
63+
if (result && result->pw_dir) {
64+
home_dir = result->pw_dir;
65+
home_dir += "/.cache";
66+
}
67+
68+
std::string cache_dir(xdg_cache_home ? xdg_cache_home : home_dir.c_str());
69+
if (cache_dir.size() == 0) {
70+
return "";
71+
}
72+
73+
struct stat cache_dir_stat;
74+
if (-1 == stat(cache_dir.c_str(), &cache_dir_stat)) {
75+
if (errno == ENOENT) {
76+
if (-1 == mkdir(cache_dir.c_str(), 0700)) return "";
77+
}
78+
}
79+
80+
std::string keycache_dir = cache_dir + "/scitokens";
81+
if (-1 == stat(keycache_dir.c_str(), &cache_dir_stat)) {
82+
if (errno == ENOENT) {
83+
if (-1 == mkdir(keycache_dir.c_str(), 0700)) return "";
84+
}
85+
}
86+
87+
std::string keycache_file = keycache_dir + "/scitokens_cpp.sqllite";
88+
initialize_cachedb(keycache_file);
89+
90+
return keycache_file;
91+
}
92+
93+
94+
void
95+
remove_issuer_entry(sqlite3 *db, const std::string &issuer, bool new_transaction) {
96+
97+
if (new_transaction) sqlite3_exec(db, "BEGIN", 0, 0 , 0);
98+
99+
sqlite3_stmt *stmt;
100+
int rc = sqlite3_prepare_v2(db, "DELETE FROM keycache WHERE issuer = ?", -1, &stmt, NULL);
101+
if (rc != SQLITE_OK) {
102+
sqlite3_close(db);
103+
return;
104+
}
105+
106+
if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
107+
sqlite3_finalize(stmt);
108+
sqlite3_close(db);
109+
return;
110+
}
111+
112+
rc = sqlite3_step(stmt);
113+
if (rc != SQLITE_DONE) {
114+
sqlite3_finalize(stmt);
115+
sqlite3_close(db);
116+
return;
117+
}
118+
119+
sqlite3_finalize(stmt);
120+
121+
if (new_transaction) sqlite3_exec(db, "COMMIT", 0, 0 , 0);
122+
}
123+
124+
}
125+
126+
127+
bool
128+
scitokens::Validator::get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update) {
129+
auto cache_fname = get_cache_file();
130+
if (cache_fname.size() == 0) {return false;}
131+
132+
sqlite3 *db;
133+
int rc = sqlite3_open(cache_fname.c_str(), &db);
134+
if (rc) {
135+
sqlite3_close(db);
136+
return false;
137+
}
138+
139+
sqlite3_stmt *stmt;
140+
rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?", -1, &stmt, NULL);
141+
if (rc != SQLITE_OK) {
142+
sqlite3_close(db);
143+
return false;
144+
}
145+
146+
if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
147+
sqlite3_finalize(stmt);
148+
sqlite3_close(db);
149+
return false;
150+
}
151+
152+
rc = sqlite3_step(stmt);
153+
if (rc == SQLITE_ROW) {
154+
const unsigned char * data = sqlite3_column_text(stmt, 0);
155+
std::string metadata(reinterpret_cast<const char *>(data));
156+
sqlite3_finalize(stmt);
157+
picojson::value json_obj;
158+
auto err = picojson::parse(json_obj, metadata);
159+
if (!err.empty() || !json_obj.is<picojson::object>()) {
160+
remove_issuer_entry(db, issuer, true);
161+
sqlite3_close(db);
162+
return false;
163+
}
164+
auto top_obj = json_obj.get<picojson::object>();
165+
auto iter = top_obj.find("jwks");
166+
if (iter == top_obj.end() || !iter->second.is<picojson::object>()) {
167+
remove_issuer_entry(db, issuer, true);
168+
sqlite3_close(db);
169+
return false;
170+
}
171+
auto keys_local = iter->second;
172+
iter = top_obj.find("expires");
173+
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
174+
remove_issuer_entry(db, issuer, true);
175+
sqlite3_close(db);
176+
return false;
177+
}
178+
auto expiry = iter->second.get<int64_t>();
179+
if (now > expiry) {
180+
remove_issuer_entry(db, issuer, true);
181+
sqlite3_close(db);
182+
return false;
183+
}
184+
sqlite3_close(db);
185+
iter = top_obj.find("next_update");
186+
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
187+
next_update = expiry - 4*3600;
188+
} else {
189+
next_update = iter->second.get<int64_t>();
190+
}
191+
keys = keys_local;
192+
return true;
193+
} else if (rc == SQLITE_DONE) {
194+
sqlite3_finalize(stmt);
195+
sqlite3_close(db);
196+
return false;
197+
} else {
198+
// TODO: log error?
199+
sqlite3_finalize(stmt);
200+
sqlite3_close(db);
201+
return false;
202+
}
203+
}
204+
205+
206+
bool
207+
scitokens::Validator::store_public_keys(const std::string &issuer, const picojson::value &keys, int64_t next_update, int64_t expires) {
208+
picojson::object top_obj;
209+
top_obj["jwks"] = keys;
210+
top_obj["next_update"] = picojson::value(next_update);
211+
top_obj["expires"] = picojson::value(expires);
212+
picojson::value db_value(top_obj);
213+
std::string db_str = db_value.serialize();
214+
215+
auto cache_fname = get_cache_file();
216+
if (cache_fname.size() == 0) {return false;}
217+
218+
sqlite3 *db;
219+
int rc = sqlite3_open(cache_fname.c_str(), &db);
220+
if (rc) {
221+
sqlite3_close(db);
222+
return false;
223+
}
224+
225+
sqlite3_exec(db, "BEGIN", 0, 0 , 0);
226+
227+
remove_issuer_entry(db, issuer, false);
228+
229+
sqlite3_stmt *stmt;
230+
rc = sqlite3_prepare_v2(db, "INSERT INTO keycache VALUES (?, ?)", -1, &stmt, NULL);
231+
if (rc != SQLITE_OK) {
232+
sqlite3_close(db);
233+
return false;
234+
}
235+
236+
if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
237+
sqlite3_finalize(stmt);
238+
sqlite3_close(db);
239+
return false;
240+
}
241+
242+
if (sqlite3_bind_text(stmt, 2, db_str.c_str(), db_str.size(), SQLITE_STATIC) != SQLITE_OK) {
243+
sqlite3_finalize(stmt);
244+
sqlite3_close(db);
245+
return false;
246+
}
247+
248+
rc = sqlite3_step(stmt);
249+
if (rc != SQLITE_DONE) {
250+
sqlite3_finalize(stmt);
251+
sqlite3_close(db);
252+
return false;
253+
}
254+
255+
sqlite3_exec(db, "COMMIT", 0, 0 , 0);
256+
257+
return true;
258+
}

src/scitokens_internal.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ SciToken::deserialize(const std::string &data) {
271271

272272

273273
void
274-
Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, std::string &public_pem, std::string &algorithm)
274+
Validator::get_public_keys_from_web(const std::string &issuer, picojson::value &keys, int64_t &next_update, int64_t &expires)
275275
{
276276
std::string openid_metadata, oauth_metadata;
277277
get_metadata_endpoint(issuer, openid_metadata, oauth_metadata);
@@ -316,45 +316,75 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
316316
throw JsonException(err);
317317
}
318318

319-
auto key_obj = find_key_id(json_obj, kid);
319+
auto now = std::time(NULL);
320+
// TODO: take expiration time from the cache-control header in the response.
320321

321-
iter = key_obj.find("alg");
322+
keys = json_obj;
323+
324+
next_update = now + 600;
325+
expires = now + 4*3600;
326+
}
327+
328+
void
329+
Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, std::string &public_pem, std::string &algorithm) {
330+
331+
picojson::value keys;
332+
int64_t next_update, expires;
333+
auto now = std::time(NULL);
334+
if (get_public_keys_from_db(issuer, now, keys, next_update)) {
335+
if (now > next_update) {
336+
try {
337+
get_public_keys_from_web(issuer, keys, next_update, expires);
338+
store_public_keys(issuer, keys, next_update, expires);
339+
} catch (std::runtime_error &) {
340+
// ignore the exception: we have a valid set of keys already/
341+
}
342+
}
343+
} else {
344+
get_public_keys_from_web(issuer, keys, next_update, expires);
345+
store_public_keys(issuer, keys, next_update, expires);
346+
}
347+
348+
auto key_obj = find_key_id(keys, kid);
349+
350+
auto iter = key_obj.find("alg");
322351
if (iter == key_obj.end() || (!iter->second.is<std::string>())) {
323352
throw JsonException("Key is missing algorithm name");
324-
}
353+
}
325354
auto alg = iter->second.get<std::string>();
326355
if (alg != "RS256" and alg != "ES256") {
327356
throw UnsupportedKeyException("Issuer is using an unsupported algorithm");
328-
}
357+
}
329358
std::string pem;
330359

331360
if (alg == "ES256")
332361
{
333362
iter = key_obj.find("x");
334363
if (iter == key_obj.end() || (!iter->second.is<std::string>())) {
335364
throw JsonException("Elliptic curve is missing x-coordinate");
336-
}
365+
}
337366
auto x = iter->second.get<std::string>();
338367
iter = key_obj.find("y");
339368
if (iter == key_obj.end() || (!iter->second.is<std::string>())) {
340369
throw JsonException("Elliptic curve is missing y-coordinate");
341-
}
370+
}
342371
auto y = iter->second.get<std::string>();
343372
pem = es256_from_coords(x, y);
344373
} else {
345374
iter = key_obj.find("e");
346375
if (iter == key_obj.end() || (!iter->second.is<std::string>())) {
347376
throw JsonException("Public key is missing exponent");
348-
}
377+
}
349378
auto e = iter->second.get<std::string>();
350379
iter = key_obj.find("n");
351380
if (iter == key_obj.end() || (!iter->second.is<std::string>())) {
352381
throw JsonException("Public key is missing n-value");
353-
}
382+
}
354383
auto n = iter->second.get<std::string>();
355384
pem = rs256_from_coords(e, n);
356-
}
357-
385+
}
386+
358387
public_pem = pem;
359388
algorithm = alg;
360389
}
390+

src/scitokens_internal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ class Validator {
176176

177177
private:
178178
void get_public_key_pem(const std::string &issuer, const std::string &kid, std::string &public_pem, std::string &algorithm);
179+
void get_public_keys_from_web(const std::string &issuer, picojson::value &keys, int64_t &next_update, int64_t &expires);
180+
bool get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update);
181+
bool store_public_keys(const std::string &issuer, const picojson::value &keys, int64_t next_update, int64_t expires);
179182
};
180183

181184
}

0 commit comments

Comments
 (0)