Skip to content

Commit 7858357

Browse files
committed
Adding deserialize async functions - checkpoint
1 parent 471c2b1 commit 7858357

File tree

5 files changed

+197
-57
lines changed

5 files changed

+197
-57
lines changed

src/scitokens.cpp

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,68 @@ int scitoken_deserialize_v2(const char *value, SciToken token, char const* const
255255
return 0;
256256
}
257257

258+
int scitoken_deserialize_start(const char *value, SciToken *token, char const* const* allowed_issuers, SciTokenStatus *status_out, char **err_msg) {
259+
if (value == nullptr) {
260+
if (err_msg) {*err_msg = strdup("Token may not be NULL");}
261+
return -1;
262+
}
263+
if (token == nullptr) {
264+
if (err_msg) {*err_msg = strdup("Output token not provided");}
265+
return -1;
266+
}
267+
268+
scitokens::SciTokenKey key;
269+
scitokens::SciToken *real_token = new scitokens::SciToken(key);
270+
271+
272+
std::vector<std::string> allowed_issuers_vec;
273+
if (allowed_issuers != nullptr) {
274+
for (int idx=0; allowed_issuers[idx]; idx++) {
275+
allowed_issuers_vec.push_back(allowed_issuers[idx]);
276+
}
277+
}
278+
279+
std::unique_ptr<scitokens::SciTokenAsyncStatus> status;
280+
try {
281+
status = real_token->deserialize_start(value, allowed_issuers_vec);
282+
} catch (std::exception &exc) {
283+
if (err_msg) {
284+
*err_msg = strdup(exc.what());
285+
}
286+
delete real_token;
287+
return -1;
288+
}
289+
*token = real_token;
290+
*status_out = status.release();
291+
return 0;
292+
}
293+
294+
int scitoken_deserialize_continue(SciToken token, SciTokenStatus *status, char **err_msg) {
295+
if (token == nullptr) {
296+
if (err_msg) {*err_msg = strdup("Output token not provided");}
297+
return -1;
298+
}
299+
scitokens::SciToken *real_token = reinterpret_cast<scitokens::SciToken*>(token);
300+
std::unique_ptr<scitokens::SciTokenAsyncStatus> real_status(reinterpret_cast<scitokens::SciTokenAsyncStatus*>(*status));
301+
302+
try {
303+
real_status = real_token->deserialize_continue(std::move(real_status));
304+
} catch (std::exception &exc) {
305+
*status = nullptr;
306+
if (err_msg) {
307+
*err_msg = strdup(exc.what());
308+
}
309+
return -1;
310+
}
311+
312+
if (real_status->m_status->m_done) {
313+
*status = nullptr;
314+
} else {
315+
*status = real_status.release();
316+
}
317+
return 0;
318+
}
319+
258320
int scitoken_store_public_ec_key(const char *issuer, const char *keyid, const char *key, char **err_msg)
259321
{
260322
bool success;
@@ -458,7 +520,7 @@ int enforcer_generate_acls_start(const Enforcer enf, const SciToken scitoken,
458520
auto real_scitoken = reinterpret_cast<scitokens::SciToken*>(scitoken);
459521

460522
scitokens::Enforcer::AclsList acls_list;
461-
std::unique_ptr<scitokens::Validator::AsyncStatus> status;
523+
std::unique_ptr<scitokens::AsyncStatus> status;
462524
try {
463525
status = real_enf->generate_acls_start(*real_scitoken, acls_list);
464526
} catch (std::exception &exc) {
@@ -491,8 +553,8 @@ int enforcer_generate_acls_continue(const Enforcer enf, SciTokenStatus *status,
491553
}
492554

493555
scitokens::Enforcer::AclsList acls_list;
494-
std::unique_ptr<scitokens::Validator::AsyncStatus> status_internal(
495-
reinterpret_cast<scitokens::Validator::AsyncStatus*>(*status));
556+
std::unique_ptr<scitokens::AsyncStatus> status_internal(
557+
reinterpret_cast<scitokens::AsyncStatus*>(*status));
496558
try {
497559
status_internal = real_enf->generate_acls_continue(std::move(status_internal), acls_list);
498560
} catch (std::exception &exc) {
@@ -539,8 +601,8 @@ int enforcer_test(const Enforcer enf, const SciToken scitoken, const Acl *acl, c
539601

540602

541603
void scitoken_status_free(SciTokenStatus status) {
542-
std::unique_ptr<scitokens::Validator::AsyncStatus> status_real(
543-
reinterpret_cast<scitokens::Validator::AsyncStatus*>(status));
604+
std::unique_ptr<scitokens::AsyncStatus> status_real(
605+
reinterpret_cast<scitokens::AsyncStatus*>(status));
544606
}
545607

546608

@@ -555,7 +617,7 @@ int scitoken_status_get_timeout_val(const SciTokenStatus *status, time_t expiry_
555617
return -1;
556618
}
557619

558-
auto real_status = reinterpret_cast<const scitokens::Validator::AsyncStatus*>(status);
620+
auto real_status = reinterpret_cast<const scitokens::AsyncStatus*>(status);
559621
struct timeval timeout_internal = real_status->get_timeout_val(expiry_time);
560622
timeout->tv_sec = timeout_internal.tv_sec;
561623
timeout->tv_usec = timeout_internal.tv_usec;
@@ -575,7 +637,7 @@ int scitoken_status_get_read_fd_set(SciTokenStatus *status, fd_set **read_fd_set
575637
return -1;
576638
}
577639

578-
auto real_status = reinterpret_cast<scitokens::Validator::AsyncStatus*>(status);
640+
auto real_status = reinterpret_cast<scitokens::AsyncStatus*>(status);
579641
*read_fd_set = real_status->get_read_fd_set();
580642
return 0;
581643
}
@@ -592,7 +654,7 @@ int scitoken_status_get_write_fd_set(SciTokenStatus *status, fd_set **write_fd_s
592654
return -1;
593655
}
594656

595-
auto real_status = reinterpret_cast<scitokens::Validator::AsyncStatus*>(status);
657+
auto real_status = reinterpret_cast<scitokens::AsyncStatus*>(status);
596658
*write_fd_set = real_status->get_write_fd_set();
597659
return 0;
598660
}
@@ -609,7 +671,7 @@ int scitoken_status_get_exc_fd_set(SciTokenStatus *status, fd_set **exc_fd_set,
609671
return -1;
610672
}
611673

612-
auto real_status = reinterpret_cast<scitokens::Validator::AsyncStatus*>(status);
674+
auto real_status = reinterpret_cast<scitokens::AsyncStatus*>(status);
613675
*exc_fd_set = real_status->get_exc_fd_set();
614676
return 0;
615677
}
@@ -626,7 +688,7 @@ int scitoken_status_get_max_fd(const SciTokenStatus *status, int *max_fd, char *
626688
return -1;
627689
}
628690

629-
auto real_status = reinterpret_cast<const scitokens::Validator::AsyncStatus*>(status);
691+
auto real_status = reinterpret_cast<const scitokens::AsyncStatus*>(status);
630692
*max_fd = real_status->get_max_fd();
631693
return 0;
632694
}

src/scitokens.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,23 @@ void scitoken_set_deserialize_profile(SciToken token, SciTokenProfile profile);
9292

9393
int scitoken_deserialize(const char *value, SciToken *token, char const* const* allowed_issuers, char **err_msg);
9494

95+
/**
96+
* @brief Start the deserialization process for a token, returning a status object.
97+
*
98+
* @param value The serialized token.
99+
* @param token Destination for the token object.
100+
* @param allowed_issuers List of allowed issuers, or nullptr for no issuer check.
101+
* @param status Destination for the status object.
102+
* @param err_msg Destination for error message.
103+
* @return int 0 on success, -1 on error.
104+
*/
105+
106+
int scitoken_deserialize_start(const char *value, SciToken *token, char const* const* allowed_issuers,
107+
SciTokenStatus *status, char **err_msg);
108+
109+
110+
int scitoken_deserialize_continue(SciToken *token, SciTokenStatus *status, char **err_msg);
111+
95112
int scitoken_deserialize_v2(const char *value, SciToken token, char const* const* allowed_issuers, char **err_msg);
96113

97114
int scitoken_store_public_ec_key(const char *issuer, const char *keyid, const char *value, char **err_msg);

src/scitokens_internal.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,37 @@ SciToken::deserialize(const std::string &data, const std::vector<std::string> al
447447
m_profile = val.get_profile();
448448
}
449449

450+
std::unique_ptr<SciTokenAsyncStatus>
451+
SciToken::deserialize_start(const std::string &data, const std::vector<std::string> allowed_issuers) {
452+
m_decoded.reset(new jwt::decoded_jwt(data));
453+
454+
scitokens::Validator val;
455+
val.add_allowed_issuers(allowed_issuers);
456+
val.set_validate_all_claims_scitokens_1(false);
457+
val.set_validate_profile(m_deserialize_profile);
458+
std::unique_ptr<SciTokenAsyncStatus> status(new SciTokenAsyncStatus());
459+
status->m_status = val.verify_async(*m_decoded);
460+
461+
return std::move(deserialize_continue(std::move(status)));
462+
}
463+
464+
std::unique_ptr<SciTokenAsyncStatus>
465+
SciToken::deserialize_continue(std::unique_ptr<SciTokenAsyncStatus> status) {
466+
467+
status->m_status = status->m_validator->verify_async_continue(std::move(status->m_status));
468+
// Check if the status is completed
469+
if (status->m_status) {
470+
// Set all the claims
471+
m_claims = m_decoded->get_payload_claims();
472+
473+
// Copy over the profile
474+
m_profile = m_profile;
475+
}
476+
return std::move(status);
477+
}
478+
450479

451-
std::unique_ptr<Validator::AsyncStatus>
480+
std::unique_ptr<AsyncStatus>
452481
Validator::get_public_keys_from_web(const std::string &issuer)
453482
{
454483
std::string openid_metadata, oauth_metadata;
@@ -466,7 +495,7 @@ Validator::get_public_keys_from_web(const std::string &issuer)
466495
}
467496

468497

469-
std::unique_ptr<Validator::AsyncStatus>
498+
std::unique_ptr<AsyncStatus>
470499
Validator::get_public_keys_from_web_continue(std::unique_ptr<AsyncStatus> status)
471500
{
472501
char *buffer;
@@ -537,7 +566,7 @@ Validator::get_public_keys_from_web_continue(std::unique_ptr<AsyncStatus> status
537566
}
538567

539568

540-
std::unique_ptr<Validator::AsyncStatus>
569+
std::unique_ptr<AsyncStatus>
541570
Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, std::string &public_pem, std::string &algorithm) {
542571

543572
auto now = std::time(NULL);
@@ -566,8 +595,8 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
566595
}
567596
}
568597

569-
std::unique_ptr<Validator::AsyncStatus>
570-
Validator::get_public_key_pem_continue(std::unique_ptr<Validator::AsyncStatus> status, std::string &public_pem, std::string &algorithm) {
598+
std::unique_ptr<AsyncStatus>
599+
Validator::get_public_key_pem_continue(std::unique_ptr<AsyncStatus> status, std::string &public_pem, std::string &algorithm) {
571600

572601
if (status->m_continue_fetch) {
573602
try {

src/scitokens_internal.h

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,55 @@ class SciTokenKey {
157157

158158
class Validator;
159159

160+
class AsyncStatus {
161+
public:
162+
AsyncStatus() = default;
163+
AsyncStatus(const AsyncStatus &) = delete;
164+
AsyncStatus & operator=(const AsyncStatus &) = delete;
165+
166+
bool m_done{false};
167+
bool m_continue_fetch{false};
168+
bool m_ignore_error{false};
169+
bool m_do_store{true};
170+
bool m_has_metadata{false};
171+
bool m_oauth_fallback{false};
172+
173+
int64_t m_next_update{-1};
174+
int64_t m_expires{-1};
175+
picojson::value m_keys;
176+
std::string m_issuer;
177+
std::string m_kid;
178+
std::string m_oauth_metadata_url;
179+
std::unique_ptr<internal::SimpleCurlGet> m_cget;
180+
std::string m_jwt_string;
181+
std::string m_public_pem;
182+
std::string m_algorithm;
183+
184+
struct timeval get_timeout_val(time_t expiry_time) const {
185+
auto now = time(NULL);
186+
long timeout_ms = 100*(expiry_time - now);
187+
if (m_cget && (m_cget->get_timeout_ms() < timeout_ms)) timeout_ms = m_cget->get_timeout_ms();
188+
struct timeval timeout;
189+
timeout.tv_sec = timeout_ms / 1000;
190+
timeout.tv_usec = (timeout_ms % 1000) * 1000;
191+
return timeout;
192+
}
193+
194+
int get_max_fd() const {return m_cget ? m_cget->get_max_fd() : -1;}
195+
fd_set *get_read_fd_set() {return m_cget ? m_cget->get_read_fd_set() : nullptr;}
196+
fd_set *get_write_fd_set() {return m_cget ? m_cget->get_write_fd_set() : nullptr;}
197+
fd_set *get_exc_fd_set() {return m_cget ? m_cget->get_exc_fd_set() : nullptr;}
198+
};
199+
200+
class SciTokenAsyncStatus {
201+
public:
202+
SciTokenAsyncStatus() = default;
203+
SciTokenAsyncStatus(const SciTokenAsyncStatus &) = delete;
204+
SciTokenAsyncStatus & operator=(const SciTokenAsyncStatus &) = delete;
205+
206+
std::unique_ptr<Validator> m_validator;
207+
std::unique_ptr<AsyncStatus> m_status;
208+
};
160209

161210
class SciToken {
162211

@@ -286,6 +335,10 @@ friend class scitokens::Validator;
286335
void
287336
deserialize(const std::string &data, std::vector<std::string> allowed_issuers={});
288337

338+
std::unique_ptr<SciTokenAsyncStatus> deserialize_start(const std::string &data, std::vector<std::string> allowed_issuers={});
339+
340+
std::unique_ptr<SciTokenAsyncStatus> deserialize_continue(std::unique_ptr<SciTokenAsyncStatus> status);
341+
289342
private:
290343
bool m_issuer_set{false};
291344
int m_lifetime{600};
@@ -306,46 +359,6 @@ class Validator {
306359

307360
public:
308361

309-
class AsyncStatus {
310-
public:
311-
AsyncStatus() = default;
312-
AsyncStatus(const AsyncStatus &) = delete;
313-
AsyncStatus & operator=(const AsyncStatus &) = delete;
314-
315-
bool m_done{false};
316-
bool m_continue_fetch{false};
317-
bool m_ignore_error{false};
318-
bool m_do_store{true};
319-
bool m_has_metadata{false};
320-
bool m_oauth_fallback{false};
321-
322-
int64_t m_next_update{-1};
323-
int64_t m_expires{-1};
324-
picojson::value m_keys;
325-
std::string m_issuer;
326-
std::string m_kid;
327-
std::string m_oauth_metadata_url;
328-
std::unique_ptr<internal::SimpleCurlGet> m_cget;
329-
std::string m_jwt_string;
330-
std::string m_public_pem;
331-
std::string m_algorithm;
332-
333-
struct timeval get_timeout_val(time_t expiry_time) const {
334-
auto now = time(NULL);
335-
long timeout_ms = 100*(expiry_time - now);
336-
if (m_cget && (m_cget->get_timeout_ms() < timeout_ms)) timeout_ms = m_cget->get_timeout_ms();
337-
struct timeval timeout;
338-
timeout.tv_sec = timeout_ms / 1000;
339-
timeout.tv_usec = (timeout_ms % 1000) * 1000;
340-
return timeout;
341-
}
342-
343-
int get_max_fd() const {return m_cget ? m_cget->get_max_fd() : -1;}
344-
fd_set *get_read_fd_set() {return m_cget ? m_cget->get_read_fd_set() : nullptr;}
345-
fd_set *get_write_fd_set() {return m_cget ? m_cget->get_write_fd_set() : nullptr;}
346-
fd_set *get_exc_fd_set() {return m_cget ? m_cget->get_exc_fd_set() : nullptr;}
347-
};
348-
349362
std::unique_ptr<AsyncStatus> verify_async(const SciToken &scitoken) {
350363
const jwt::decoded_jwt *jwt_decoded = scitoken.m_decoded.get();
351364
if (!jwt_decoded) {
@@ -701,7 +714,7 @@ class Enforcer {
701714
}
702715

703716

704-
std::unique_ptr<Validator::AsyncStatus> generate_acls_start(const SciToken &scitoken, AclsList &acls) {
717+
std::unique_ptr<AsyncStatus> generate_acls_start(const SciToken &scitoken, AclsList &acls) {
705718
reset_state();
706719
auto status = m_validator.verify_async(scitoken);
707720
if (status->m_done) {
@@ -711,7 +724,7 @@ class Enforcer {
711724
}
712725

713726

714-
std::unique_ptr<Validator::AsyncStatus> generate_acls_continue(std::unique_ptr<Validator::AsyncStatus> status, AclsList &acls) {
727+
std::unique_ptr<AsyncStatus> generate_acls_continue(std::unique_ptr<AsyncStatus> status, AclsList &acls) {
715728
auto result = m_validator.verify_async_continue(std::move(status));
716729
if (result->m_done) {
717730
acls = m_gen_acls;

test/main.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,28 @@ TEST_F(SerializeTest, EnforcerScopeTest) {
344344
ASSERT_TRUE(found_read);
345345
ASSERT_TRUE(found_write);
346346

347+
}
348+
349+
TEST_F(SerializeTest, DeserializeAsyncTest) {
350+
char *err_msg = nullptr;
351+
352+
// Serialize as "compat" token.
353+
char *token_value = nullptr;
354+
scitoken_set_serialize_profile(m_token.get(), SciTokenProfile::COMPAT);
355+
auto rv = scitoken_serialize(m_token.get(), &token_value, &err_msg);
356+
ASSERT_TRUE(rv == 0);
357+
std::unique_ptr<char, decltype(&free)> token_value_ptr(token_value, free);
347358

359+
SciToken scitoken;
360+
SciTokenStatus status;
348361

362+
// Accepts any profile.
363+
rv = scitoken_deserialize_start(token_value, &scitoken, nullptr, &status, &err_msg);
364+
ASSERT_TRUE(rv == 0);
349365

366+
// Accepts only an at+jwt token, should fail with COMPAT token
367+
rv = scitoken_deserialize_continue(&scitoken, &status, &err_msg);
368+
ASSERT_FALSE(rv == 0);
350369
}
351370

352371
}

0 commit comments

Comments
 (0)