@@ -26,7 +26,6 @@ use mas_policy::Policy;
2626use mas_router:: UrlBuilder ;
2727use mas_storage:: {
2828 BoxRepository , RepositoryAccess ,
29- queue:: { ProvisionUserJob , QueueJobRepositoryExt as _} ,
3029 upstream_oauth2:: { UpstreamOAuthLinkRepository , UpstreamOAuthSessionRepository } ,
3130 user:: { BrowserSessionRepository , UserEmailRepository , UserRepository } ,
3231} ;
@@ -46,7 +45,7 @@ use super::{
4645} ;
4746use crate :: {
4847 BoundActivityTracker , METER , PreferredLanguage , SiteConfig , impl_from_error_for_route,
49- views:: shared:: OptionalPostAuthAction ,
48+ views:: { register :: UserRegistrationSessionsCookie , shared:: OptionalPostAuthAction } ,
5049} ;
5150
5251static LOGIN_COUNTER : LazyLock < Counter < u64 > > = LazyLock :: new ( || {
@@ -610,10 +609,6 @@ pub(crate) async fn post(
610609 . lookup_link ( link_id)
611610 . map_err ( |_| RouteError :: MissingCookie ) ?;
612611
613- let post_auth_action = OptionalPostAuthAction {
614- post_auth_action : post_auth_action. cloned ( ) ,
615- } ;
616-
617612 let link = repo
618613 . upstream_oauth_link ( )
619614 . lookup ( link_id)
@@ -641,15 +636,35 @@ pub(crate) async fn post(
641636 let maybe_user_session = user_session_info. load_active_session ( & mut repo) . await ?;
642637 let form_state = form. to_form_state ( ) ;
643638
644- let session = match ( maybe_user_session, link. user_id , form) {
639+ match ( maybe_user_session, link. user_id , form) {
645640 ( Some ( session) , None , FormData :: Link ) => {
646641 // The user is already logged in, the link is not linked to any user, and the
647642 // user asked to link their account.
648643 repo. upstream_oauth_link ( )
649644 . associate_to_user ( & link, & session. user )
650645 . await ?;
651646
652- session
647+ let upstream_session = repo
648+ . upstream_oauth_session ( )
649+ . consume ( & clock, upstream_session)
650+ . await ?;
651+
652+ repo. browser_session ( )
653+ . authenticate_with_upstream ( & mut rng, & clock, & session, & upstream_session)
654+ . await ?;
655+
656+ let post_auth_action = OptionalPostAuthAction {
657+ post_auth_action : post_auth_action. cloned ( ) ,
658+ } ;
659+
660+ let cookie_jar = sessions_cookie
661+ . consume_link ( link_id) ?
662+ . save ( cookie_jar, & clock) ;
663+ let cookie_jar = cookie_jar. set_session ( & session) ;
664+
665+ repo. save ( ) . await ?;
666+
667+ Ok ( ( cookie_jar, post_auth_action. go_next ( & url_builder) ) . into_response ( ) )
653668 }
654669
655670 ( None , None , FormData :: Link ) => {
@@ -714,14 +729,38 @@ pub(crate) async fn post(
714729 return Err ( RouteError :: InvalidFormAction ) ;
715730 }
716731 UpstreamOAuthProviderOnConflict :: Add => {
717- //add link to the user
732+ // Add link to the user
718733 repo. upstream_oauth_link ( )
719734 . associate_to_user ( & link, & user)
720735 . await ?;
721736
722- repo. browser_session ( )
737+ // And sign in the user
738+ let session = repo
739+ . browser_session ( )
723740 . add ( & mut rng, & clock, & user, user_agent)
724- . await ?
741+ . await ?;
742+
743+ let upstream_session = repo
744+ . upstream_oauth_session ( )
745+ . consume ( & clock, upstream_session)
746+ . await ?;
747+
748+ repo. browser_session ( )
749+ . authenticate_with_upstream ( & mut rng, & clock, & session, & upstream_session)
750+ . await ?;
751+
752+ let post_auth_action = OptionalPostAuthAction {
753+ post_auth_action : post_auth_action. cloned ( ) ,
754+ } ;
755+
756+ let cookie_jar = sessions_cookie
757+ . consume_link ( link_id) ?
758+ . save ( cookie_jar, & clock) ;
759+ let cookie_jar = cookie_jar. set_session ( & session) ;
760+
761+ repo. save ( ) . await ?;
762+
763+ Ok ( ( cookie_jar, post_auth_action. go_next ( & url_builder) ) . into_response ( ) )
725764 }
726765 }
727766 }
@@ -950,61 +989,84 @@ pub(crate) async fn post(
950989
951990 REGISTRATION_COUNTER . add ( 1 , & [ KeyValue :: new ( PROVIDER , provider. id . to_string ( ) ) ] ) ;
952991
953- // Now we can create the user
954- let user = repo. user ( ) . add ( & mut rng, & clock, username) . await ?;
992+ let mut registration = repo
993+ . user_registration ( )
994+ . add (
995+ & mut rng,
996+ & clock,
997+ username,
998+ activity_tracker. ip ( ) ,
999+ user_agent,
1000+ post_auth_action. map ( |action| serde_json:: json!( action) ) ,
1001+ )
1002+ . await ?;
9551003
9561004 if let Some ( terms_url) = & site_config. tos_uri {
957- repo. user_terms ( )
958- . accept_terms ( & mut rng, & clock, & user, terms_url. clone ( ) )
1005+ registration = repo
1006+ . user_registration ( )
1007+ . set_terms_url ( registration, terms_url. clone ( ) )
9591008 . await ?;
9601009 }
9611010
962- // And schedule the job to provision it
963- let mut job = ProvisionUserJob :: new ( & user) ;
1011+ // If we have an email, add an email authentication and complete it
1012+ if let Some ( email) = email {
1013+ let authentication = repo
1014+ . user_email ( )
1015+ . add_authentication_for_registration ( & mut rng, & clock, email, & registration)
1016+ . await ?;
1017+ let authentication = repo
1018+ . user_email ( )
1019+ . complete_authentication_with_upstream (
1020+ & clock,
1021+ authentication,
1022+ & upstream_session,
1023+ )
1024+ . await ?;
9641025
965- // If we have a display name, set it during provisioning
966- if let Some ( name) = display_name {
967- job = job. set_display_name ( name) ;
1026+ registration = repo
1027+ . user_registration ( )
1028+ . set_email_authentication ( registration, & authentication)
1029+ . await ?;
9681030 }
9691031
970- repo. queue_job ( ) . schedule_job ( & mut rng, & clock, job) . await ?;
971-
972- // If we have an email, add it to the user
973- if let Some ( email) = email {
974- repo. user_email ( )
975- . add ( & mut rng, & clock, & user, email)
1032+ // If we have a display name, add it to the registration
1033+ if let Some ( name) = display_name {
1034+ registration = repo
1035+ . user_registration ( )
1036+ . set_display_name ( registration, name)
9761037 . await ?;
9771038 }
9781039
979- repo. upstream_oauth_link ( )
980- . associate_to_user ( & link, & user)
1040+ let registration = repo
1041+ . user_registration ( )
1042+ . set_upstream_oauth_authorization_session ( registration, & upstream_session)
9811043 . await ?;
9821044
983- repo. browser_session ( )
984- . add ( & mut rng, & clock, & user, user_agent)
985- . await ?
986- }
1045+ repo. upstream_oauth_session ( )
1046+ . consume ( & clock, upstream_session)
1047+ . await ?;
9871048
988- _ => return Err ( RouteError :: InvalidFormAction ) ,
989- } ;
1049+ let registrations = UserRegistrationSessionsCookie :: load ( & cookie_jar) ;
9901050
991- let upstream_session = repo
992- . upstream_oauth_session ( )
993- . consume ( & clock, upstream_session)
994- . await ?;
1051+ let cookie_jar = sessions_cookie
1052+ . consume_link ( link_id) ?
1053+ . save ( cookie_jar, & clock) ;
9951054
996- repo. browser_session ( )
997- . authenticate_with_upstream ( & mut rng, & clock, & session, & upstream_session)
998- . await ?;
1055+ let cookie_jar = registrations. add ( & registration) . save ( cookie_jar, & clock) ;
9991056
1000- let cookie_jar = sessions_cookie
1001- . consume_link ( link_id) ?
1002- . save ( cookie_jar, & clock) ;
1003- let cookie_jar = cookie_jar. set_session ( & session) ;
1057+ repo. save ( ) . await ?;
10041058
1005- repo. save ( ) . await ?;
1059+ // Redirect to the user registration flow, in case we have any other step to
1060+ // finish
1061+ Ok ( (
1062+ cookie_jar,
1063+ url_builder. redirect ( & mas_router:: RegisterFinish :: new ( registration. id ) ) ,
1064+ )
1065+ . into_response ( ) )
1066+ }
10061067
1007- Ok ( ( cookie_jar, post_auth_action. go_next ( & url_builder) ) . into_response ( ) )
1068+ _ => Err ( RouteError :: InvalidFormAction ) ,
1069+ }
10081070}
10091071
10101072#[ cfg( test) ]
@@ -1013,20 +1075,18 @@ mod tests {
10131075 use mas_data_model:: {
10141076 UpstreamOAuthAuthorizationSession , UpstreamOAuthLink , UpstreamOAuthProviderClaimsImports ,
10151077 UpstreamOAuthProviderImportPreference , UpstreamOAuthProviderLocalpartPreference ,
1016- UpstreamOAuthProviderTokenAuthMethod ,
1078+ UpstreamOAuthProviderTokenAuthMethod , UserEmailAuthentication , UserRegistration ,
10171079 } ;
10181080 use mas_iana:: jose:: JsonWebSignatureAlg ;
10191081 use mas_jose:: jwt:: { JsonWebSignatureHeader , Jwt } ;
10201082 use mas_keystore:: Keystore ;
10211083 use mas_router:: Route ;
1022- use mas_storage:: {
1023- Pagination , Repository , RepositoryError , upstream_oauth2:: UpstreamOAuthProviderParams ,
1024- user:: UserEmailFilter ,
1025- } ;
1084+ use mas_storage:: { Repository , RepositoryError , upstream_oauth2:: UpstreamOAuthProviderParams } ;
10261085 use oauth2_types:: scope:: { OPENID , Scope } ;
10271086 use rand_chacha:: ChaChaRng ;
10281087 use serde_json:: Value ;
10291088 use sqlx:: PgPool ;
1089+ use ulid:: Ulid ;
10301090
10311091 use super :: UpstreamSessionsCookie ;
10321092 use crate :: test_utils:: { CookieHelper , RequestBuilderExt , ResponseExt , TestState , setup} ;
@@ -1188,33 +1248,41 @@ mod tests {
11881248 let response = state. request ( request) . await ;
11891249 cookies. save_cookies ( & response) ;
11901250 response. assert_status ( StatusCode :: SEE_OTHER ) ;
1251+ let location = response. headers ( ) . get ( hyper:: header:: LOCATION ) . unwrap ( ) ;
1252+ // Grab the registration ID from the redirected URL:
1253+ // /register/steps/{id}/finish
1254+ let registration_id: Ulid = str:: from_utf8 ( location. as_bytes ( ) )
1255+ . unwrap ( )
1256+ . rsplit ( '/' )
1257+ . nth ( 1 )
1258+ . expect ( "Location to have two slashes" )
1259+ . parse ( )
1260+ . expect ( "last segment of location to be a ULID" ) ;
11911261
11921262 // Check that we have a registered user, with the email imported
11931263 let mut repo = state. repository ( ) . await . unwrap ( ) ;
1194- let user = repo
1195- . user ( )
1196- . find_by_username ( "john" )
1197- . await
1198- . unwrap ( )
1199- . expect ( "user exists" ) ;
1200-
1201- let link = repo
1202- . upstream_oauth_link ( )
1203- . find_by_subject ( & provider, "subject" )
1264+ let registration: UserRegistration = repo
1265+ . user_registration ( )
1266+ . lookup ( registration_id)
12041267 . await
12051268 . unwrap ( )
1206- . expect ( "link exists" ) ;
1269+ . expect ( "user registration exists" ) ;
12071270
1208- assert_eq ! ( link. user_id, Some ( user. id) ) ;
1271+ assert_eq ! ( registration. password, None ) ;
1272+ assert_eq ! ( registration. completed_at, None ) ;
1273+ assert_eq ! ( registration. username, "john" ) ;
12091274
1210- let page = repo
1275+ let email_auth_id = registration
1276+ . email_authentication_id
1277+ . expect ( "registration should have an email authentication" ) ;
1278+ let email_auth: UserEmailAuthentication = repo
12111279 . user_email ( )
1212- . list ( UserEmailFilter :: new ( ) . for_user ( & user ) , Pagination :: first ( 1 ) )
1280+ . lookup_authentication ( email_auth_id )
12131281 . await
1214- . unwrap ( ) ;
1215- let edge = page . edges . first ( ) . expect ( "email exists " ) ;
1216-
1217- assert_eq ! ( edge . node . email , "john@example.com" ) ;
1282+ . unwrap ( )
1283+ . expect ( "email authentication should exist " ) ;
1284+ assert_eq ! ( email_auth . email , "john@example.com" ) ;
1285+ assert ! ( email_auth . completed_at . is_some ( ) ) ;
12181286 }
12191287
12201288 #[ sqlx:: test( migrator = "mas_storage_pg::MIGRATOR" ) ]
0 commit comments