diff --git a/src/store.rs b/src/store.rs index 0ae2db3..43b86b6 100644 --- a/src/store.rs +++ b/src/store.rs @@ -110,7 +110,7 @@ pub struct UserInfo { /// A wrapper around a database transaction with some database methods on it. pub struct IdCoopStoreTxn<'a, 'txn> { - txn: &'a mut Transaction<'txn, Postgres>, + pub(crate) txn: &'a mut Transaction<'txn, Postgres>, } impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { @@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { .await .context("failed to lookup login session")?; - let Some(row) = row_opt else { return Ok(None); }; + let Some(row) = row_opt else { + return Ok(None); + }; Ok(Some(LoginSession { user_name: row.user_name, @@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { .await .context("failed to lookup application session")?; - let Some(row) = row_opt else { return Ok(None); }; + let Some(row) = row_opt else { + return Ok(None); + }; Ok(Some(ApplicationSession { application_session_id: row.session_id, diff --git a/src/tests.rs b/src/tests.rs index 557016c..7da6ba2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use axum::Router; use confique::{Config, Partial}; use josekit::jwk::alg::rsa::RsaKeyPair; +use metrics::atomics::AtomicU64; use pgtemp::PgTempDB; use rand::SeedableRng; use rand_xoshiro::Xoshiro256StarStar; @@ -20,6 +21,7 @@ struct TestSystem { web: Router, config: Arc, store: Arc, + clock: Clock, } const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem"); @@ -85,13 +87,13 @@ async fn basic_system() -> TestSystem { let config = Arc::new(config); let store = Arc::new(store); - let clock = Clock::Fake(); + let clock = Clock(Arc::new(AtomicU64::new(0))); let randgen = RandGen(Xoshiro256StarStar::seed_from_u64(424242)); let router = make_router( store.clone(), config.clone(), Arc::new(secrets), - clock, + clock.clone(), randgen, ) .await @@ -102,6 +104,7 @@ async fn basic_system() -> TestSystem { web: router, config, store, + clock, } } diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap index 20336ee..9f2db4a 100644 --- a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap @@ -4,7 +4,7 @@ expression: "(headers, text)" --- - content-length: "55" content-type: text/plain; charset=utf-8 - location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey" + location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey" set-cookie: __Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000 x-frame-options: DENY - Logged in. Redirecting you back to what you were doing. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap new file mode 100644 index 0000000..5a4c2fc --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap @@ -0,0 +1,8 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "288" + content-type: text/html; charset=utf-8 + x-frame-options: DENY +- "hi robert, consent to AClient?
" diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap new file mode 100644 index 0000000..ad770d6 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "46" + content-type: text/plain; charset=utf-8 + location: "http://aclient.example.com/redirect?code=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO&state=wombat&iss=http%3A%2F%2Fissuer.example.com" + x-frame-options: DENY +- Authorisation succeeded; redirecting you back. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap new file mode 100644 index 0000000..776bebb --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap @@ -0,0 +1,14 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "803" + content-type: application/json +- access_token: HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE + expires_in: 31536000 + id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg + refresh_token: bn0vvGrpCiQDMaH7MRw7FgHuPqoEYel0IFGYv_Z6E7o + scope: openid + token_type: Bearer diff --git a/src/tests/test_oidc_auth_flow.rs b/src/tests/test_oidc_auth_flow.rs index d768ea6..6584235 100644 --- a/src/tests/test_oidc_auth_flow.rs +++ b/src/tests/test_oidc_auth_flow.rs @@ -7,6 +7,7 @@ use axum_test_helper::{TestClient, TestResponse}; use insta::assert_yaml_snapshot; use maplit::btreemap; +use sqlx::types::Uuid; use crate::{passwords::create_password_hash, store::CreateUser, tests::basic_system}; @@ -21,10 +22,15 @@ async fn dump_resp_text( .headers() .clone() .into_iter() - .map(|(k, v)| (k.unwrap().to_string(), v.to_str().unwrap().to_owned())) + // skip headers that are duplicate keys. + // We don't care about theme for our purposes. + .filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned()))) .collect(); // Remove date because it's not stable across tests! headers.remove("date"); + // Remove vary because it has multiple values and we don't want to + // introduce instability into our tests by only allowing one through. + headers.remove("vary"); let text = resp.text().await; eprintln!("=== Response for {req_name} ==="); eprintln!("Status: {status:?}"); @@ -36,21 +42,20 @@ async fn dump_resp_text( /// Tests the full flow... #[tokio::test] -async fn test_todo() { +async fn test_full_flow() { let sys = basic_system().await; + let uuid = Uuid::nil(); let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap(); let _: () = sys .store .txn(|mut txn| { Box::pin(async move { - txn.create_user(CreateUser { - user_login_name: "robert".to_owned(), - password_hash: Some(pwhash), - locked: false, - }) - .await - .unwrap(); + sqlx::query( + "INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id", + ).bind("robert").bind(uuid).bind(pwhash).bind(false) + .fetch_one(&mut **txn.txn) + .await.unwrap(); Ok(()) }) }) @@ -59,22 +64,28 @@ async fn test_todo() { let client = TestClient::new(sys.web); + ///// These requests are on behalf of the user's browser ///// + + const CODE_VERIFIER: &str = "verifying"; + // base64(sha256("verifying")) + const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M"; + // 1. /auth request - const LOGIN_URL: &str = "/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3Dchallenging%26code_challenge_method%3DS256%26nonce%3Dnoncey"; - let resp = client.get("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey").send().await; + let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey"); + let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await; let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await; assert_eq!(status, 302); - assert_eq!(headers.get("location").unwrap(), LOGIN_URL); + assert_eq!(headers.get("location").unwrap(), &login_url); // 2. /login request - let resp = client.get(LOGIN_URL).send().await; + let resp = client.get(&login_url).send().await; let (status, headers, text) = dump_resp_text("2. /login request", resp).await; assert_eq!(status, 200); assert_yaml_snapshot!("2/login", (headers, text)); // 3. /login request with credentials let resp = client - .post(LOGIN_URL) + .post(&login_url) .form(&btreemap! { "username" => "robert", "password" => "secret", @@ -87,5 +98,58 @@ async fn test_todo() { .await; let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await; assert_eq!(status, 302); + let auth_loc = headers.get("location").unwrap().to_owned(); assert_yaml_snapshot!("3/login", (headers, text)); + + // 4. /auth request + let resp = client + .get(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE", + ) + .send() + .await; + let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await; + assert_eq!(status, 200); + assert_yaml_snapshot!("4/auth", (headers, text)); + + sys.clock.set_time(30); + + // 5. /auth request with confirmation + let resp = client + .post(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE", + ) + .form(&btreemap! { + "action" => "accept", + "xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await; + assert_eq!(status, 302); + // Note this snapshot includes a Location: header back to the client application + assert_yaml_snapshot!("5/auth", (headers, text)); + + ///// At this point, we make requests on behalf of the client ///// + + // 6. /token request + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO", + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("6. POST /token", resp).await; + assert_eq!(status, 200); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("6/token", (headers, json)); } diff --git a/src/utils.rs b/src/utils.rs index 22f03af..d9b5b5b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,7 +1,16 @@ //! Miscellaneous utilities +#[cfg(not(test))] +pub use self::real_utils::{Clock, RandGen}; + +#[cfg(test)] +pub use self::test_utils::{Clock, RandGen}; + #[cfg(not(test))] mod real_utils { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use chrono::{DateTime, Utc}; use rand::{thread_rng, RngCore}; /// A source of random numbers that can be faked for tests. @@ -25,12 +34,47 @@ mod real_utils { thread_rng().try_fill_bytes(dest) } } + + /// A source of time that can be faked for tests. + #[derive(Clone)] + pub struct Clock; + + impl Clock { + /// Returns the current time as a `DateTime`. + pub fn now_utc(&self) -> DateTime { + Utc::now() + } + + /// Returns the current time as a i64 of seconds since the Unix epoch. + pub fn now_timestamp(&self) -> i64 { + Utc::now().timestamp() + } + + /// Sleep until the given timestamp. + pub async fn sleep_until(&self, until_ts: i64) { + if until_ts < 0 { + return; + } + let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64); + let now = SystemTime::now(); + let sleep_for = sleep_until + .duration_since(now) + .unwrap_or(Duration::from_secs(0)); + tokio::time::sleep(sleep_for).await + } + } } #[cfg(test)] mod test_utils { - use std::ops::{Deref, DerefMut}; + use std::sync::atomic::AtomicU64; + use std::time::Duration; + use std::{ + ops::{Deref, DerefMut}, + sync::Arc, + }; + use chrono::{DateTime, TimeZone, Utc}; use rand_xoshiro::Xoshiro256StarStar; #[derive(Clone)] @@ -48,19 +92,37 @@ mod test_utils { &mut self.0 } } -} - -#[cfg(not(test))] -pub use self::real_utils::RandGen; - -#[cfg(test)] -pub use self::test_utils::RandGen; - -/// A source of time that can be faked for tests. -#[derive(Clone)] -pub enum Clock { - /// Use real time - Real, - /// Fake time for use in tests - Fake(), + + #[derive(Clone)] + pub struct Clock(pub Arc); + + impl Clock { + pub fn new_test() -> Self { + Clock(Arc::new(AtomicU64::new(0))) + } + pub fn set_time(&self, new: u64) { + self.0.store(new, std::sync::atomic::Ordering::Relaxed); + } + + pub fn now_utc(&self) -> DateTime { + Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0) + .earliest() + .unwrap() + } + + pub fn now_timestamp(&self) -> i64 { + self.0.load(std::sync::atomic::Ordering::Relaxed) as i64 + } + + pub async fn sleep_until(&self, until_ts: i64) { + if until_ts < 0 { + return; + } + // Wait for time to advance past the requested timestamp + // TODO write a better test sleep implementation + while self.now_timestamp() < until_ts { + tokio::time::sleep(Duration::from_millis(1)).await; + } + } + } } diff --git a/src/web.rs b/src/web.rs index d0aaa68..281ff73 100644 --- a/src/web.rs +++ b/src/web.rs @@ -125,7 +125,7 @@ pub(crate) async fn make_router( .layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1)))) .layer(client_ip_source.into_extension()) .layer(Extension(Arc::new(ratelimiters))) - .layer(Extension(VolatileCodeStore::default())) + .layer(Extension(VolatileCodeStore::new(clock.clone()))) .layer(Extension(clock)) .layer(Extension(randgen)); @@ -144,7 +144,7 @@ pub async fn serve( use eyre::Context; use tracing::info; - let router = make_router(store, config, secrets, Clock::Real, RandGen).await?; + let router = make_router(store, config, secrets, Clock, RandGen).await?; info!("Listening on {bind:?}"); axum::Server::try_bind(&bind) diff --git a/src/web/login.rs b/src/web/login.rs index f63463d..cce0727 100644 --- a/src/web/login.rs +++ b/src/web/login.rs @@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, TimeZone, Utc}; use eyre::eyre; use eyre::{bail, Context, ContextCompat}; use governor::Jitter; -use rand::{thread_rng, Rng}; +use rand::Rng; use serde::Deserialize; use sqlx::types::Uuid; use tokio::sync::Semaphore; diff --git a/src/web/oauth_openid/authorisation.rs b/src/web/oauth_openid/authorisation.rs index b3a011e..b16ff7e 100644 --- a/src/web/oauth_openid/authorisation.rs +++ b/src/web/oauth_openid/authorisation.rs @@ -9,7 +9,6 @@ use axum::{ Extension, Form, }; -use chrono::Utc; use eyre::{Context, ContextCompat}; use serde::{Deserialize, Serialize}; @@ -17,6 +16,7 @@ use tracing::{error, warn}; use crate::{ config::{Configuration, OidcClientConfiguration}, + utils::{Clock, RandGen}, web::{ login::LoginSession, make_login_redirect, @@ -69,6 +69,8 @@ pub async fn oidc_authorisation( login_session: Option, Extension(config): Extension>, Extension(code_store): Extension, + Extension(clock): Extension, + Extension(mut randgen): Extension, OriginalUri(uri): OriginalUri, ) -> Response { let Query(query) = match query { @@ -109,7 +111,7 @@ pub async fn oidc_authorisation( // If the application requires consent, then we should ask for that. if !client_config.skip_consent { - return show_consent_page(login_session, client_config, &config).await; + return show_consent_page(login_session, client_config, Extension(clock), &config).await; } // No consent needed: process the authorisation. @@ -120,6 +122,8 @@ pub async fn oidc_authorisation( client_id, client_config, &config, + &mut randgen, + &clock, &code_store, ) .await @@ -133,11 +137,14 @@ pub struct PostConsentForm { } /// `POST /oidc/auth` +#[allow(clippy::too_many_arguments)] pub async fn post_oidc_authorisation_consent( Query(query): Query, login_session: Option, Extension(config): Extension>, Extension(code_store): Extension, + Extension(clock): Extension, + Extension(mut randgen): Extension, OriginalUri(uri): OriginalUri, Form(form): Form, ) -> Response { @@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent( }; if login_session - .validate_xsrf_token(&form.xsrf, Utc::now()) + .validate_xsrf_token(&form.xsrf, clock.now_utc()) .is_err() { // XSRF token is not valid, so show the consent form again... - return show_consent_page(login_session, client_config, &config).await; + return show_consent_page(login_session, client_config, Extension(clock), &config).await; } match form.action.as_str() { @@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent( client_id, client_config, &config, + &mut randgen, + &clock, &code_store, ) .await @@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>( async fn show_consent_page( login_session: LoginSession, client_config: &OidcClientConfiguration, + Extension(clock): Extension, _config: &Configuration, ) -> Response { let xsrf_token = login_session - .generate_xsrf_token(Utc::now()) + .generate_xsrf_token(clock.now_utc()) .expect("must be able to create a XSRF token"); Html(format!( "hi {}, consent to {}?
", @@ -253,12 +263,15 @@ async fn show_consent_page( /// Preconditions: /// - any required consent from the user has now been obtained /// - query.request_uri has been validated as a safe redirect URI +#[allow(clippy::too_many_arguments)] async fn process_authorisation( query: AuthorisationQuery, login_session: LoginSession, client_id: String, _client_config: &OidcClientConfiguration, config: &Configuration, + randgen: &mut RandGen, + clock: &Clock, code_store: &VolatileCodeStore, ) -> Response { assert_eq!( @@ -271,7 +284,7 @@ async fn process_authorisation( // Generate a 192-bit random code, which fits into exactly 32 base64 characters. // This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random. - let code = AuthCode::generate_new_random(); + let code = AuthCode::generate_new_random(randgen); let code_base64url = code.to_string(); // Write down the code and other details in-memory with 10 minute expiry... @@ -289,7 +302,7 @@ async fn process_authorisation( user_id: login_session.user_id, user_login_session_id: login_session.login_session_id, }, - 0, + clock.now_timestamp() + 600, ); #[derive(Serialize)] diff --git a/src/web/oauth_openid/ext_codes.rs b/src/web/oauth_openid/ext_codes.rs index 35a1bba..4c626b7 100644 --- a/src/web/oauth_openid/ext_codes.rs +++ b/src/web/oauth_openid/ext_codes.rs @@ -6,7 +6,6 @@ use std::{ fmt::Display, str::FromStr, sync::{Arc, Mutex}, - time::{Duration, SystemTime, UNIX_EPOCH}, }; use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -14,6 +13,8 @@ use rand::Rng; use sqlx::types::Uuid; use tokio::sync::Notify; +use crate::utils::{Clock, RandGen}; + /// Display shows the auth code as base64 (URL-safe non-padded). /// FromStr/parse parses the same format. #[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)] @@ -61,8 +62,8 @@ impl FromStr for AuthCode { impl AuthCode { /// Generate a new authorisation code using the thread's RNG - pub fn generate_new_random() -> Self { - Self(rand::thread_rng().gen::<[u8; 24]>()) + pub fn generate_new_random(randgen: &mut RandGen) -> Self { + Self(randgen.gen::<[u8; 24]>()) } } @@ -112,7 +113,7 @@ struct VolatileCodeStoreInner { pub conflictable_codes: HashMap, /// Time when codes will expire - pub expire_codes_at: BTreeSet<(u64, AuthCode)>, + pub expire_codes_at: BTreeSet<(i64, AuthCode)>, } impl VolatileCodeStoreInner { @@ -148,7 +149,7 @@ impl VolatileCodeStoreInner { &mut self, auth_code: AuthCode, auth_code_binding: AuthCodeBinding, - expires_at: u64, + expires_at: i64, ) { self.redeemable_codes .insert(auth_code.clone(), auth_code_binding); @@ -156,7 +157,7 @@ impl VolatileCodeStoreInner { } /// Removes all expired auth codes and returns the time of the earliest next expiry, if present. - pub(self) fn handle_expiry(&mut self, now: u64) -> Option { + pub(self) fn handle_expiry(&mut self, now: i64) -> Option { loop { let (ts, _auth_code) = self.expire_codes_at.first()?; @@ -182,15 +183,16 @@ pub struct VolatileCodeStore { inner: Arc>, } -impl Default for VolatileCodeStore { - fn default() -> Self { +impl VolatileCodeStore { + /// Create a new instance. + pub fn new(clock: Clock) -> Self { let poke = Arc::new(Notify::new()); let inner: Arc> = Default::default(); { let poke = poke.clone(); let inner = inner.clone(); - tokio::spawn(Self::expirer(inner, poke)); + tokio::spawn(Self::expirer(inner, poke, clock)); } VolatileCodeStore { inner, poke } @@ -198,19 +200,15 @@ impl Default for VolatileCodeStore { } impl VolatileCodeStore { - async fn expirer(inner: Arc>, poke: Arc) { - let mut next_expiry: Option = None; + async fn expirer(inner: Arc>, poke: Arc, clock: Clock) { + let mut next_expiry: Option = None; loop { match next_expiry { Some(next_expiry) => { - let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry); - let now = SystemTime::now(); - let sleep_for = sleep_until - .duration_since(now) - .unwrap_or(Duration::from_secs(60)); + let sleep_future = clock.sleep_until(next_expiry); tokio::select! { _ = poke.notified() => {}, - _ = tokio::time::sleep(sleep_for) => {}, + _ = sleep_future => {}, } } None => { @@ -218,14 +216,9 @@ impl VolatileCodeStore { } } - let now = SystemTime::now(); next_expiry = { let mut inner = inner.lock().unwrap(); - inner.handle_expiry( - now.duration_since(UNIX_EPOCH) - .expect("system clock before unix epoch") - .as_secs(), - ) + inner.handle_expiry(clock.now_timestamp()) }; } } @@ -244,7 +237,7 @@ impl VolatileCodeStore { } /// Add a new redeemable authorisation code. - pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) { + pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) { let mut inner = self.inner.lock().unwrap(); inner.add_redeemable(auth_code, binding, expires_at); drop(inner); @@ -284,7 +277,7 @@ mod test { use rstest::{fixture, rstest}; use sqlx::types::Uuid; - use crate::web::oauth_openid::ext_codes::CodeRedemption; + use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption}; use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner}; @@ -422,7 +415,8 @@ mod test { /// We can't easily cover everything here but may as well test the basics. #[tokio::test] async fn test_fullfat_store_basic() { - let vcs = VolatileCodeStore::default(); + let clock = Clock::new_test(); + let vcs = VolatileCodeStore::new(clock.clone()); vcs.add_redeemable( VALID_CODE.clone(), @@ -435,7 +429,7 @@ mod test { user_id: USER_UUID, user_login_session_id: USER_LOGIN_SESSION_ID, }, - u64::MAX, + i64::MAX, ); assert_matches!( @@ -458,6 +452,7 @@ mod test { // given a moment 1, ); + clock.set_time(2); vcs.add_redeemable( AuthCode([0; 24]), @@ -470,11 +465,11 @@ mod test { user_id: USER_UUID, user_login_session_id: USER_LOGIN_SESSION_ID, }, - u64::MAX, + i64::MAX, ); // Give a short time for the expiry to take place. - tokio::time::sleep(Duration::from_millis(1)).await; + tokio::time::sleep(Duration::from_millis(2)).await; assert_matches!( vcs.redeem(&VALID_CODE, [1; 32], [2; 32]), diff --git a/src/web/oauth_openid/token.rs b/src/web/oauth_openid/token.rs index 7a00e57..f0017c9 100644 --- a/src/web/oauth_openid/token.rs +++ b/src/web/oauth_openid/token.rs @@ -1,10 +1,6 @@ //! `/oidc/token` -use std::{ - str::FromStr, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::{str::FromStr, sync::Arc}; use axum::{ extract::rejection::FormRejection, @@ -15,13 +11,13 @@ use axum::{ }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use blake2::Blake2s256; -use chrono::{Duration, Utc}; +use chrono::Duration; use eyre::{bail, Context}; use josekit::{ jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader}, jwt::JwtPayload, }; -use rand::{thread_rng, Rng}; +use rand::Rng; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use subtle::ConstantTimeEq; @@ -30,6 +26,7 @@ use tracing::{debug, error}; use crate::{ config::{Configuration, SecretConfig}, store::IdCoopStore, + utils::{Clock, RandGen}, }; use super::ext_codes::{ @@ -53,12 +50,15 @@ pub struct TokenFormParams { /// OpenID Connect clients call this to exchange an authorisation code they received for an access token. /// /// TODO auth_header can be one alternative auth method +#[allow(clippy::too_many_arguments)] pub async fn oidc_token( basic_auth: Option>>, Extension(config): Extension>, Extension(secrets): Extension>, Extension(store): Extension>, Extension(code_store): Extension, + Extension(mut randgen): Extension, + Extension(clock): Extension, form: Result, FormRejection>, ) -> impl IntoResponse { let form = match form { @@ -110,8 +110,9 @@ pub async fn oidc_token( Json(TokenError { code: TokenErrorCode::InvalidClient, description: "That `client_id` is not recognised here.".to_string(), - }) - ).into_response(); + }), + ) + .into_response(); }; if !bool::from( @@ -181,10 +182,10 @@ pub async fn oidc_token( // Create an access token but don't actually issue it yet: // This lets us store the hash of the access token against the redemption of the auth code, // so double redemptions can invalidate the access token appropriately. - let access_token = thread_rng().gen::(); + let access_token = randgen.gen::(); let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token); let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into(); - let refresh_token = thread_rng().gen::(); + let refresh_token = randgen.gen::(); let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token); let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into(); @@ -256,7 +257,9 @@ pub async fn oidc_token( } // 2. Check the code challenge - let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else { + let Some(computed_code_challenge) = + compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) + else { return ( StatusCode::BAD_REQUEST, Json(TokenError { @@ -287,28 +290,29 @@ pub async fn oidc_token( let user_login_session_id = binding.user_login_session_id; // Issue access token for a new session + let clock2 = clock.clone(); match store .txn(move |mut txn| { Box::pin(async move { let Some(session_id) = txn .create_application_session(user_id, &application_id, user_login_session_id) .await - .context("create_application_session")? else { - return Ok(Err( - ( - StatusCode::BAD_REQUEST, - Json(TokenError { - code: TokenErrorCode::InvalidGrant, - description: "Auth code has expired or was not valid.".to_owned(), - }), - ).into_response() - )); + .context("create_application_session")? + else { + return Ok(Err(( + StatusCode::BAD_REQUEST, + Json(TokenError { + code: TokenErrorCode::InvalidGrant, + description: "Auth code has expired or was not valid.".to_owned(), + }), + ) + .into_response())); }; txn.issue_access_token( &access_token_hash, session_id, // TODO(expiry) Support custom expiry, not 100 years - Utc::now() + Duration::days(365 * 100), + clock2.now_utc() + Duration::days(365 * 100), ) .await .context("issue_access_token")?; @@ -316,7 +320,7 @@ pub async fn oidc_token( &refresh_token_hash, session_id, // TODO(expiry) Support custom expiry, not 100 years - Utc::now() + Duration::days(365 * 100), + clock2.now_utc() + Duration::days(365 * 100), ) .await .context("issue_refresh_token")?; @@ -344,10 +348,7 @@ pub async fn oidc_token( } } - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Before unix epoch?") - .as_secs(); + let now = clock.now_timestamp(); // TODO(expiry) Support custom expiry times (not just 100 years) let exp = now + 100 * 365 * 86400; let sub = binding.user_id.hyphenated().to_string(); @@ -385,7 +386,9 @@ pub async fn oidc_token( } fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result { - let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else { + let Ok(serde_json::Value::Object(map)) = + serde_json::to_value(id_token).context("failed to serialise ID Token content") + else { bail!("ID Token not a map"); }; @@ -447,9 +450,9 @@ pub struct IdToken { /// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew. /// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. /// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular. - pub exp: u64, + pub exp: i64, /// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. - pub iat: u64, + pub iat: i64, /// Time when the End-User authentication occurred. /// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. /// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.