Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ pub async fn policy_factory_from_config(
register: config.register_entrypoint.clone(),
client_registration: config.client_registration_entrypoint.clone(),
authorization_grant: config.authorization_grant_entrypoint.clone(),
compat_login: config.compat_login_entrypoint.clone(),
email: config.email_entrypoint.clone(),
};

Expand Down
16 changes: 16 additions & 0 deletions crates/config/src/sections/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ fn is_default_password_entrypoint(value: &String) -> bool {
*value == default_password_entrypoint()
}

fn default_compat_login_entrypoint() -> String {
"compat_login/violation".to_owned()
}

fn is_default_compat_login_entrypoint(value: &String) -> bool {
*value == default_compat_login_entrypoint()
}

fn default_email_entrypoint() -> String {
"email/violation".to_owned()
}
Expand Down Expand Up @@ -111,6 +119,13 @@ pub struct PolicyConfig {
)]
pub authorization_grant_entrypoint: String,

/// Entrypoint to use when evaluating compatibility logins
#[serde(
default = "default_compat_login_entrypoint",
skip_serializing_if = "is_default_compat_login_entrypoint"
)]
pub compat_login_entrypoint: String,

/// Entrypoint to use when changing password
#[serde(
default = "default_password_entrypoint",
Expand All @@ -137,6 +152,7 @@ impl Default for PolicyConfig {
client_registration_entrypoint: default_client_registration_entrypoint(),
register_entrypoint: default_register_entrypoint(),
authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
compat_login_entrypoint: default_compat_login_entrypoint(),
password_entrypoint: default_password_entrypoint(),
email_entrypoint: default_email_entrypoint(),
data: default_data(),
Expand Down
94 changes: 92 additions & 2 deletions crates/handlers/src/compat/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use mas_data_model::{
User,
};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, Requester, ViolationCode, model::CompatLogin};
use mas_storage::{
BoxRepository, BoxRepositoryFactory, RepositoryAccess,
compat::{
Expand All @@ -37,6 +38,7 @@ use crate::{
BoundActivityTracker, Limiter, METER, RequesterFingerprint, impl_from_error_for_route,
passwords::{PasswordManager, PasswordVerificationResult},
rate_limit::PasswordCheckLimitedError,
session::count_user_sessions_for_limiting,
};

static LOGIN_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
Expand Down Expand Up @@ -213,9 +215,16 @@ pub enum RouteError {

#[error("failed to provision device")]
ProvisionDeviceFailed(#[source] anyhow::Error),

#[error("login rejected by policy")]
PolicyRejected,

#[error("login rejected by policy (hard session limit reached)")]
PolicyHardSessionLimitReached,
}

impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);

impl From<anyhow::Error> for RouteError {
fn from(err: anyhow::Error) -> Self {
Expand Down Expand Up @@ -274,6 +283,16 @@ impl IntoResponse for RouteError {
error: "User account has been locked",
status: StatusCode::UNAUTHORIZED,
},
Self::PolicyRejected => MatrixError {
errcode: "M_FORBIDDEN",
error: "Login denied by the policy enforced by this service",
status: StatusCode::FORBIDDEN,
},
Self::PolicyHardSessionLimitReached => MatrixError {
errcode: "M_FORBIDDEN",
error: "You have reached your hard device limit. Please visit your account page to sign some out.",
status: StatusCode::FORBIDDEN,
},
};

(sentry_event_id, response).into_response()
Expand All @@ -290,6 +309,7 @@ pub(crate) async fn post(
State(homeserver): State<Arc<dyn HomeserverConnection>>,
State(site_config): State<SiteConfig>,
State(limiter): State<Limiter>,
mut policy: Policy,
requester: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
MatrixJsonBody(input): MatrixJsonBody<RequestBody>,
Expand Down Expand Up @@ -329,6 +349,11 @@ pub(crate) async fn post(
&limiter,
requester,
&mut repo,
&mut policy,
Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone(),
},
username,
password,
input.device_id, // TODO check for validity
Expand All @@ -342,6 +367,11 @@ pub(crate) async fn post(
&mut rng,
&clock,
&mut repo,
&mut policy,
Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone(),
},
&token,
input.device_id,
input.initial_device_display_name,
Expand Down Expand Up @@ -459,6 +489,8 @@ async fn token_login(
rng: &mut (dyn RngCore + Send),
clock: &dyn Clock,
repo: &mut BoxRepository,
policy: &mut Policy,
requester: Requester,
token: &str,
requested_device_id: Option<String>,
initial_device_display_name: Option<String>,
Expand Down Expand Up @@ -544,10 +576,38 @@ async fn token_login(
Device::generate(rng)
};

repo.app_session()
let session_replaced = repo
.app_session()
.finish_sessions_to_replace_device(clock, &browser_session.user, &device)
.await?;

let session_counts = count_user_sessions_for_limiting(repo, &browser_session.user).await?;

let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &browser_session.user,
login: CompatLogin::Token,
session_replaced,
session_counts,
requester,
})
.await?;
if !res.valid() {
// If the only violation is that we have too many sessions, then handle that
// separately.
// In the future, we intend to evict some sessions automatically instead. We
// don't trigger this if there was some other violation anyway, since that means
// that removing a session wouldn't actually unblock the login.
if res.violations.len() == 1 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the reasoning behind having a different message only if there is a single "too-many-sessions" violation vs. if it is one of them in the list?

In any case, I think this would benefit from a comment of what we're doing and why

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely should comment, but I think the underlying thought was: If you have too many sessions but also your client 'Tnemele' is banned, isn't showing you the UI for 'pick a device to delete' missing the point?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this comment to 8f523e3

let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}
}
return Err(RouteError::PolicyRejected);
}

// We first create the session in the database, commit the transaction, then
// create it on the homeserver, scheduling a device sync job afterwards to
// make sure we don't end up in an inconsistent state.
Expand Down Expand Up @@ -578,6 +638,8 @@ async fn user_password_login(
limiter: &Limiter,
requester: RequesterFingerprint,
repo: &mut BoxRepository,
policy: &mut Policy,
policy_requester: Requester,
username: &str,
password: String,
requested_device_id: Option<String>,
Expand Down Expand Up @@ -647,10 +709,38 @@ async fn user_password_login(
Device::generate(&mut rng)
};

repo.app_session()
let session_replaced = repo
.app_session()
.finish_sessions_to_replace_device(clock, &user, &device)
.await?;

let session_counts = count_user_sessions_for_limiting(repo, &user).await?;

let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &user,
login: CompatLogin::Password,
session_replaced,
session_counts,
requester: policy_requester,
})
.await?;
if !res.valid() {
// If the only violation is that we have too many sessions, then handle that
// separately.
// In the future, we intend to evict some sessions automatically instead. We
// don't trigger this if there was some other violation anyway, since that means
// that removing a session wouldn't actually unblock the login.
if res.violations.len() == 1 {
let violation = &res.violations[0];
if violation.code == Some(ViolationCode::TooManySessions) {
// The only violation is having reached the session limit.
return Err(RouteError::PolicyHardSessionLimitReached);
}
}
return Err(RouteError::PolicyRejected);
}

let session = repo
.compat_session()
.add(
Expand Down
82 changes: 78 additions & 4 deletions crates/handlers/src/compat/login_sso_complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,27 @@ use axum::{
extract::{Form, Path, State},
response::{Html, IntoResponse, Redirect, Response},
};
use axum_extra::extract::Query;
use axum_extra::{TypedHeader, extract::Query};
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::{
InternalError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BoxClock, BoxRng, Clock};
use mas_policy::{Policy, model::CompatLogin};
use mas_router::{CompatLoginSsoAction, UrlBuilder};
use mas_storage::{BoxRepository, RepositoryAccess, compat::CompatSsoLoginRepository};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use mas_templates::{
CompatLoginPolicyViolationContext, CompatSsoContext, ErrorContext, TemplateContext, Templates,
};
use serde::{Deserialize, Serialize};
use ulid::Ulid;

use crate::{
PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, count_user_sessions_for_limiting, load_session_or_fallback},
};

#[derive(Serialize)]
Expand Down Expand Up @@ -56,10 +60,15 @@ pub async fn get(
mut repo: BoxRepository,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, InternalError> {
let user_agent = user_agent.map(|ua| ua.to_string());

let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
Expand Down Expand Up @@ -107,6 +116,35 @@ pub async fn get(
return Ok((cookie_jar, Html(content)).into_response());
}

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;

let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
login: CompatLogin::Sso {
redirect_uri: login.redirect_uri.to_string(),
},
// We don't know if there's going to be a replacement until we received the device ID,
// which happens too late.
session_replaced: false,
session_counts,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;
if !res.valid() {
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_compat_login_policy_violation(&ctx)?;

return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
}

let ctx = CompatSsoContext::new(login)
.with_session(session)
.with_csrf(csrf_token.form_value())
Expand All @@ -129,11 +167,16 @@ pub async fn post(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, InternalError> {
let user_agent = user_agent.map(|ua| ua.to_string());

let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
Expand Down Expand Up @@ -200,6 +243,37 @@ pub async fn post(
redirect_uri
};

let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?;

let res = policy
.evaluate_compat_login(mas_policy::CompatLoginInput {
user: &session.user,
login: CompatLogin::Sso {
redirect_uri: login.redirect_uri.to_string(),
},
session_counts,
// We don't know if there's going to be a replacement until we received the device ID,
// which happens too late.
session_replaced: false,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;

if !res.valid() {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let ctx = CompatLoginPolicyViolationContext::for_violations(res.violations)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_compat_login_policy_violation(&ctx)?;

return Ok((StatusCode::FORBIDDEN, cookie_jar, Html(content)).into_response());
}

// Note that if the login is not Pending,
// this fails and aborts the transaction.
repo.compat_sso_login()
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ where
BoxRepository: FromRequestParts<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
{
// A sub-router for human-facing routes with error handling
let human_router = Router::new()
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub(crate) async fn policy_factory(
register: "register/violation".to_owned(),
client_registration: "client_registration/violation".to_owned(),
authorization_grant: "authorization_grant/violation".to_owned(),
compat_login: "compat_login/violation".to_owned(),
email: "email/violation".to_owned(),
};

Expand Down
Loading
Loading