Skip to content

Commit 4b6c1db

Browse files
authored
Unify registrations for local passwords and upstream OAuth registrations (#5281)
2 parents bbad15a + 61ee8da commit 4b6c1db

15 files changed

+553
-130
lines changed

crates/data-model/src/users.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ pub struct UserRegistration {
272272
pub email_authentication_id: Option<Ulid>,
273273
pub user_registration_token_id: Option<Ulid>,
274274
pub password: Option<UserRegistrationPassword>,
275+
pub upstream_oauth_authorization_session_id: Option<Ulid>,
275276
pub post_auth_action: Option<serde_json::Value>,
276277
pub ip_address: Option<IpAddr>,
277278
pub user_agent: Option<String>,

crates/handlers/src/graphql/mutations/user_email.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ impl UserEmailMutations {
817817

818818
let authentication = repo
819819
.user_email()
820-
.complete_authentication(&clock, authentication, &code)
820+
.complete_authentication_with_code(&clock, authentication, &code)
821821
.await?;
822822

823823
// Check the email is not already in use by anyone, including the current user

crates/handlers/src/upstream_oauth2/link.rs

Lines changed: 138 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ use mas_policy::Policy;
2626
use mas_router::UrlBuilder;
2727
use mas_storage::{
2828
BoxRepository, RepositoryAccess,
29-
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
3029
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
3130
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
3231
};
@@ -46,7 +45,7 @@ use super::{
4645
};
4746
use crate::{
4847
BoundActivityTracker, METER, PreferredLanguage, SiteConfig, impl_from_error_for_route,
49-
views::shared::OptionalPostAuthAction,
48+
views::{register::UserRegistrationSessionsCookie, shared::OptionalPostAuthAction},
5049
};
5150

5251
static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
@@ -610,10 +609,6 @@ pub(crate) async fn post(
610609
.lookup_link(link_id)
611610
.map_err(|_| RouteError::MissingCookie)?;
612611

613-
let post_auth_action = OptionalPostAuthAction {
614-
post_auth_action: post_auth_action.cloned(),
615-
};
616-
617612
let link = repo
618613
.upstream_oauth_link()
619614
.lookup(link_id)
@@ -641,15 +636,35 @@ pub(crate) async fn post(
641636
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
642637
let form_state = form.to_form_state();
643638

644-
let session = match (maybe_user_session, link.user_id, form) {
639+
match (maybe_user_session, link.user_id, form) {
645640
(Some(session), None, FormData::Link) => {
646641
// The user is already logged in, the link is not linked to any user, and the
647642
// user asked to link their account.
648643
repo.upstream_oauth_link()
649644
.associate_to_user(&link, &session.user)
650645
.await?;
651646

652-
session
647+
let upstream_session = repo
648+
.upstream_oauth_session()
649+
.consume(&clock, upstream_session)
650+
.await?;
651+
652+
repo.browser_session()
653+
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
654+
.await?;
655+
656+
let post_auth_action = OptionalPostAuthAction {
657+
post_auth_action: post_auth_action.cloned(),
658+
};
659+
660+
let cookie_jar = sessions_cookie
661+
.consume_link(link_id)?
662+
.save(cookie_jar, &clock);
663+
let cookie_jar = cookie_jar.set_session(&session);
664+
665+
repo.save().await?;
666+
667+
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
653668
}
654669

655670
(None, None, FormData::Link) => {
@@ -714,14 +729,38 @@ pub(crate) async fn post(
714729
return Err(RouteError::InvalidFormAction);
715730
}
716731
UpstreamOAuthProviderOnConflict::Add => {
717-
//add link to the user
732+
// Add link to the user
718733
repo.upstream_oauth_link()
719734
.associate_to_user(&link, &user)
720735
.await?;
721736

722-
repo.browser_session()
737+
// And sign in the user
738+
let session = repo
739+
.browser_session()
723740
.add(&mut rng, &clock, &user, user_agent)
724-
.await?
741+
.await?;
742+
743+
let upstream_session = repo
744+
.upstream_oauth_session()
745+
.consume(&clock, upstream_session)
746+
.await?;
747+
748+
repo.browser_session()
749+
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
750+
.await?;
751+
752+
let post_auth_action = OptionalPostAuthAction {
753+
post_auth_action: post_auth_action.cloned(),
754+
};
755+
756+
let cookie_jar = sessions_cookie
757+
.consume_link(link_id)?
758+
.save(cookie_jar, &clock);
759+
let cookie_jar = cookie_jar.set_session(&session);
760+
761+
repo.save().await?;
762+
763+
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
725764
}
726765
}
727766
}
@@ -950,61 +989,84 @@ pub(crate) async fn post(
950989

951990
REGISTRATION_COUNTER.add(1, &[KeyValue::new(PROVIDER, provider.id.to_string())]);
952991

953-
// Now we can create the user
954-
let user = repo.user().add(&mut rng, &clock, username).await?;
992+
let mut registration = repo
993+
.user_registration()
994+
.add(
995+
&mut rng,
996+
&clock,
997+
username,
998+
activity_tracker.ip(),
999+
user_agent,
1000+
post_auth_action.map(|action| serde_json::json!(action)),
1001+
)
1002+
.await?;
9551003

9561004
if let Some(terms_url) = &site_config.tos_uri {
957-
repo.user_terms()
958-
.accept_terms(&mut rng, &clock, &user, terms_url.clone())
1005+
registration = repo
1006+
.user_registration()
1007+
.set_terms_url(registration, terms_url.clone())
9591008
.await?;
9601009
}
9611010

962-
// And schedule the job to provision it
963-
let mut job = ProvisionUserJob::new(&user);
1011+
// If we have an email, add an email authentication and complete it
1012+
if let Some(email) = email {
1013+
let authentication = repo
1014+
.user_email()
1015+
.add_authentication_for_registration(&mut rng, &clock, email, &registration)
1016+
.await?;
1017+
let authentication = repo
1018+
.user_email()
1019+
.complete_authentication_with_upstream(
1020+
&clock,
1021+
authentication,
1022+
&upstream_session,
1023+
)
1024+
.await?;
9641025

965-
// If we have a display name, set it during provisioning
966-
if let Some(name) = display_name {
967-
job = job.set_display_name(name);
1026+
registration = repo
1027+
.user_registration()
1028+
.set_email_authentication(registration, &authentication)
1029+
.await?;
9681030
}
9691031

970-
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
971-
972-
// If we have an email, add it to the user
973-
if let Some(email) = email {
974-
repo.user_email()
975-
.add(&mut rng, &clock, &user, email)
1032+
// If we have a display name, add it to the registration
1033+
if let Some(name) = display_name {
1034+
registration = repo
1035+
.user_registration()
1036+
.set_display_name(registration, name)
9761037
.await?;
9771038
}
9781039

979-
repo.upstream_oauth_link()
980-
.associate_to_user(&link, &user)
1040+
let registration = repo
1041+
.user_registration()
1042+
.set_upstream_oauth_authorization_session(registration, &upstream_session)
9811043
.await?;
9821044

983-
repo.browser_session()
984-
.add(&mut rng, &clock, &user, user_agent)
985-
.await?
986-
}
1045+
repo.upstream_oauth_session()
1046+
.consume(&clock, upstream_session)
1047+
.await?;
9871048

988-
_ => return Err(RouteError::InvalidFormAction),
989-
};
1049+
let registrations = UserRegistrationSessionsCookie::load(&cookie_jar);
9901050

991-
let upstream_session = repo
992-
.upstream_oauth_session()
993-
.consume(&clock, upstream_session)
994-
.await?;
1051+
let cookie_jar = sessions_cookie
1052+
.consume_link(link_id)?
1053+
.save(cookie_jar, &clock);
9951054

996-
repo.browser_session()
997-
.authenticate_with_upstream(&mut rng, &clock, &session, &upstream_session)
998-
.await?;
1055+
let cookie_jar = registrations.add(&registration).save(cookie_jar, &clock);
9991056

1000-
let cookie_jar = sessions_cookie
1001-
.consume_link(link_id)?
1002-
.save(cookie_jar, &clock);
1003-
let cookie_jar = cookie_jar.set_session(&session);
1057+
repo.save().await?;
10041058

1005-
repo.save().await?;
1059+
// Redirect to the user registration flow, in case we have any other step to
1060+
// finish
1061+
Ok((
1062+
cookie_jar,
1063+
url_builder.redirect(&mas_router::RegisterFinish::new(registration.id)),
1064+
)
1065+
.into_response())
1066+
}
10061067

1007-
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
1068+
_ => Err(RouteError::InvalidFormAction),
1069+
}
10081070
}
10091071

10101072
#[cfg(test)]
@@ -1013,20 +1075,18 @@ mod tests {
10131075
use mas_data_model::{
10141076
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProviderClaimsImports,
10151077
UpstreamOAuthProviderImportPreference, UpstreamOAuthProviderLocalpartPreference,
1016-
UpstreamOAuthProviderTokenAuthMethod,
1078+
UpstreamOAuthProviderTokenAuthMethod, UserEmailAuthentication, UserRegistration,
10171079
};
10181080
use mas_iana::jose::JsonWebSignatureAlg;
10191081
use mas_jose::jwt::{JsonWebSignatureHeader, Jwt};
10201082
use mas_keystore::Keystore;
10211083
use mas_router::Route;
1022-
use mas_storage::{
1023-
Pagination, Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams,
1024-
user::UserEmailFilter,
1025-
};
1084+
use mas_storage::{Repository, RepositoryError, upstream_oauth2::UpstreamOAuthProviderParams};
10261085
use oauth2_types::scope::{OPENID, Scope};
10271086
use rand_chacha::ChaChaRng;
10281087
use serde_json::Value;
10291088
use sqlx::PgPool;
1089+
use ulid::Ulid;
10301090

10311091
use super::UpstreamSessionsCookie;
10321092
use crate::test_utils::{CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
@@ -1188,33 +1248,41 @@ mod tests {
11881248
let response = state.request(request).await;
11891249
cookies.save_cookies(&response);
11901250
response.assert_status(StatusCode::SEE_OTHER);
1251+
let location = response.headers().get(hyper::header::LOCATION).unwrap();
1252+
// Grab the registration ID from the redirected URL:
1253+
// /register/steps/{id}/finish
1254+
let registration_id: Ulid = str::from_utf8(location.as_bytes())
1255+
.unwrap()
1256+
.rsplit('/')
1257+
.nth(1)
1258+
.expect("Location to have two slashes")
1259+
.parse()
1260+
.expect("last segment of location to be a ULID");
11911261

11921262
// Check that we have a registered user, with the email imported
11931263
let mut repo = state.repository().await.unwrap();
1194-
let user = repo
1195-
.user()
1196-
.find_by_username("john")
1197-
.await
1198-
.unwrap()
1199-
.expect("user exists");
1200-
1201-
let link = repo
1202-
.upstream_oauth_link()
1203-
.find_by_subject(&provider, "subject")
1264+
let registration: UserRegistration = repo
1265+
.user_registration()
1266+
.lookup(registration_id)
12041267
.await
12051268
.unwrap()
1206-
.expect("link exists");
1269+
.expect("user registration exists");
12071270

1208-
assert_eq!(link.user_id, Some(user.id));
1271+
assert_eq!(registration.password, None);
1272+
assert_eq!(registration.completed_at, None);
1273+
assert_eq!(registration.username, "john");
12091274

1210-
let page = repo
1275+
let email_auth_id = registration
1276+
.email_authentication_id
1277+
.expect("registration should have an email authentication");
1278+
let email_auth: UserEmailAuthentication = repo
12111279
.user_email()
1212-
.list(UserEmailFilter::new().for_user(&user), Pagination::first(1))
1280+
.lookup_authentication(email_auth_id)
12131281
.await
1214-
.unwrap();
1215-
let edge = page.edges.first().expect("email exists");
1216-
1217-
assert_eq!(edge.node.email, "john@example.com");
1282+
.unwrap()
1283+
.expect("email authentication should exist");
1284+
assert_eq!(email_auth.email, "john@example.com");
1285+
assert!(email_auth.completed_at.is_some());
12181286
}
12191287

12201288
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]

crates/handlers/src/views/register/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ mod cookie;
2121
pub(crate) mod password;
2222
pub(crate) mod steps;
2323

24+
pub use self::cookie::UserRegistrationSessions as UserRegistrationSessionsCookie;
25+
2426
#[tracing::instrument(name = "handlers.views.register.get", skip_all)]
2527
pub(crate) async fn get(
2628
mut rng: BoxRng,

0 commit comments

Comments
 (0)