@@ -189,10 +189,28 @@ struct local_base64url : public jwt::alphabet::base64url {
189189};
190190
191191
192+ // Assuming a padding, decode
193+ std::string b64url_decode_nopadding (const std::string &input)
194+ {
195+ std::string result = input;
196+ switch (result.size () % 4 ) {
197+ case 1 :
198+ result += " =" ; // fallthrough
199+ case 2 :
200+ result += " =" ; // fallthrough
201+ case 3 :
202+ result += " =" ; // fallthrough
203+ default :
204+ break ;
205+ }
206+ return jwt::base::decode<local_base64url>(result);
207+ }
208+
209+
192210std::string
193211es256_from_coords (const std::string &x_str, const std::string &y_str) {
194- auto x_decode = jwt::base::decode<local_base64url> (x_str);
195- auto y_decode = jwt::base::decode<local_base64url> (y_str);
212+ auto x_decode = b64url_decode_nopadding (x_str);
213+ auto y_decode = b64url_decode_nopadding (y_str);
196214
197215 std::unique_ptr<EC_KEY, decltype (&EC_KEY_free)> ec (EC_KEY_new_by_curve_name (NID_X9_62_prime256v1), EC_KEY_free);
198216 if (!ec.get ()) {
@@ -232,8 +250,8 @@ es256_from_coords(const std::string &x_str, const std::string &y_str) {
232250
233251std::string
234252rs256_from_coords (const std::string &e_str, const std::string &n_str) {
235- auto e_decode = jwt::base::decode<local_base64url> (e_str);
236- auto n_decode = jwt::base::decode<local_base64url> (n_str);
253+ auto e_decode = b64url_decode_nopadding (e_str);
254+ auto n_decode = b64url_decode_nopadding (n_str);
237255 std::unique_ptr<BIGNUM, decltype (&BN_free)> e_bignum (BN_bin2bn (reinterpret_cast <const unsigned char *>(e_decode.c_str ()), e_decode.size (), nullptr ), BN_free);
238256 std::unique_ptr<BIGNUM, decltype (&BN_free)> n_bignum (BN_bin2bn (reinterpret_cast <const unsigned char *>(n_decode.c_str ()), n_decode.size (), nullptr ), BN_free);
239257
@@ -399,10 +417,33 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
399417 auto key_obj = find_key_id (keys, kid);
400418
401419 auto iter = key_obj.find (" alg" );
420+ std::string alg;
402421 if (iter == key_obj.end () || (!iter->second .is <std::string>())) {
403- throw JsonException (" Key is missing algorithm name" );
404- }
405- auto alg = iter->second .get <std::string>();
422+ auto iter2 = key_obj.find (" kty" );
423+ if (iter2 == key_obj.end () || !iter2->second .is <std::string>()) {
424+ throw JsonException (" Key is missing key type" );
425+ } else {
426+ auto kty = iter2->second .get <std::string>();
427+ if (kty == " RSA" ) {
428+ alg = " RS256" ;
429+ } else if (kty == " EC" ) {
430+ auto iter3 = key_obj.find (" crv" );
431+ if (iter3 == key_obj.end () || !iter3->second .is <std::string>()) {
432+ throw JsonException (" EC key is missing curve name" );
433+ }
434+ auto crv = iter2->second .get <std::string>();
435+ if (crv == " P-256" ) {
436+ alg = " EC256" ;
437+ } else {
438+ throw JsonException (" Unsupported EC curve in public key" );
439+ }
440+ } else {
441+ throw JsonException (" Unknown public key type" );
442+ }
443+ }
444+ } else {
445+ alg = iter->second .get <std::string>();
446+ }
406447 if (alg != " RS256" and alg != " ES256" ) {
407448 throw UnsupportedKeyException (" Issuer is using an unsupported algorithm" );
408449 }
0 commit comments