From 396539774877061d9d439dbf95b5cc5b9f86640c Mon Sep 17 00:00:00 2001 From: Olivier 'reivilibre Date: Sat, 6 Jul 2024 14:31:35 +0100 Subject: [PATCH] Tests for auth codes Signed-off-by: Olivier 'reivilibre --- Cargo.lock | 7 + Cargo.toml | 1 + src/web/oauth_openid/ext_codes.rs | 218 +++++++++++++++++++++++++++++- 3 files changed, 225 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index cf79d9b..23e967e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,12 @@ dependencies = [ "password-hash", ] +[[package]] +name = "assert_matches2" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15832d94c458da98cac0ffa6eca52cc19c2a3c6c951058500a5ae8f01f0fdf56" + [[package]] name = "async-recursion" version = "1.0.5" @@ -1633,6 +1639,7 @@ name = "idcoop" version = "0.0.1" dependencies = [ "argon2", + "assert_matches2", "async-trait", "axum", "axum-client-ip", diff --git a/Cargo.toml b/Cargo.toml index 179952c..2e817f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } [dev-dependencies] +assert_matches2 = "0.1.2" axum-test-helper = "0.3.0" insta = { version = "1.39.0", features = ["serde", "yaml"] } pgtemp = "0.3.0" diff --git a/src/web/oauth_openid/ext_codes.rs b/src/web/oauth_openid/ext_codes.rs index 6b55b56..35a1bba 100644 --- a/src/web/oauth_openid/ext_codes.rs +++ b/src/web/oauth_openid/ext_codes.rs @@ -2,6 +2,7 @@ use std::{ collections::{BTreeSet, HashMap}, + fmt::Debug, fmt::Display, str::FromStr, sync::{Arc, Mutex}, @@ -16,7 +17,7 @@ use tokio::sync::Notify; /// Display shows the auth code as base64 (URL-safe non-padded). /// FromStr/parse parses the same format. #[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)] -pub struct AuthCode([u8; 24]); +pub struct AuthCode(pub [u8; 24]); /// Access token pub type AccessToken = [u8; 32]; @@ -37,6 +38,12 @@ impl Display for AuthCode { } } +impl Debug for AuthCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + impl FromStr for AuthCode { type Err = &'static str; @@ -62,6 +69,7 @@ impl AuthCode { /// Binding between an authorisation code (ready to be redeemed) /// and both the user that authenticated to produce it /// as well as the OpenID Connect client that it is for. +#[derive(PartialEq, Eq, Debug)] pub struct AuthCodeBinding { /// ID of the OpenID Connect client pub client_id: String, @@ -246,6 +254,7 @@ impl VolatileCodeStore { } /// Possible outcomes of an attempt to redeem an authorisation code. +#[derive(PartialEq, Eq, Debug)] pub enum CodeRedemption { /// That auth code was not active Invalid, @@ -266,3 +275,210 @@ pub enum CodeRedemption { refresh_token_to_invalidate: RefreshTokenHash, }, } + +#[cfg(test)] +mod test { + use std::{str::FromStr, time::Duration}; + + use assert_matches2::assert_matches; + use rstest::{fixture, rstest}; + use sqlx::types::Uuid; + + use crate::web::oauth_openid::ext_codes::CodeRedemption; + + use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner}; + + const VALID_CODE: AuthCode = AuthCode([21; 24]); + + const USER_UUID: Uuid = Uuid::nil(); + const USER_LOGIN_SESSION_ID: i32 = 1347; + + #[fixture] + fn code_store() -> VolatileCodeStoreInner { + let mut vcs = VolatileCodeStoreInner::default(); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + 128, + ); + vcs + } + + #[rstest] + fn test_redeem_nonexistent_code(mut code_store: VolatileCodeStoreInner) { + let ac = AuthCode([0; 24]); + + let result = code_store.redeem(&ac, [42; 32], [43; 32]); + + assert_eq!(result, CodeRedemption::Invalid); + } + + #[rstest] + fn test_redeem_real_code(mut code_store: VolatileCodeStoreInner) { + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + + assert_matches!(result, CodeRedemption::Valid { binding }); + + assert_eq!(binding.client_id, "client_id"); + } + + #[rstest] + fn test_redeem_code_twice(mut code_store: VolatileCodeStoreInner) { + // redeem a first time + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + + // redeem a second time + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!( + result, + CodeRedemption::Conflicted { + access_token_to_invalidate, + refresh_token_to_invalidate + } + ); + + assert_eq!(access_token_to_invalidate, [42; 32]); + assert_eq!(refresh_token_to_invalidate, [43; 32]); + } + + #[rstest] + fn test_expire_and_redeem_not_expired_yet(mut code_store: VolatileCodeStoreInner) { + code_store.handle_expiry(127); + + // The code shouldn't expire yet. + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + } + + #[rstest] + fn test_expire_and_redeem_expired_now(mut code_store: VolatileCodeStoreInner) { + code_store.handle_expiry(128); + + // The code should have expired + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_eq!(result, CodeRedemption::Invalid); + } + + #[rstest] + fn test_expire_and_redeem_conflict(mut code_store: VolatileCodeStoreInner) { + // redeem a first time + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + + code_store.handle_expiry(127); + + // redeem a second time: the conflict should not have expired yet. + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!(result, CodeRedemption::Conflicted { .. }); + + // redeem a third time to show the conflict doesn't get removed. + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!(result, CodeRedemption::Conflicted { .. }); + + code_store.handle_expiry(128); + + // The code and its conflict entry should have expired + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_eq!(result, CodeRedemption::Invalid); + } + + #[test] + fn test_authcode_display() { + assert_eq!(VALID_CODE.to_string(), "FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV"); + assert_eq!( + format!("{VALID_CODE:?}"), + "FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV" + ); + } + + #[test] + fn test_authcode_from_string() { + assert_eq!( + AuthCode::from_str(&VALID_CODE.to_string()).unwrap(), + VALID_CODE + ); + } + + #[test] + fn test_authcode_from_string_wrong_size() { + // Not a valid base64 string + assert_eq!(AuthCode::from_str("a").unwrap_err(), "wrong size"); + + // Not 24 bytes when decoded + assert_eq!(AuthCode::from_str("abcd").unwrap_err(), "wrong size"); + } + + /// Basic test for the full-fat VolatileCodeStore, not just its inner logic + /// We can't easily cover everything here but may as well test the basics. + #[tokio::test] + async fn test_fullfat_store_basic() { + let vcs = VolatileCodeStore::default(); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + u64::MAX, + ); + + assert_matches!( + vcs.redeem(&VALID_CODE, [1; 32], [2; 32]), + CodeRedemption::Valid { .. } + ); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + // Expiry in the past: this should be auto-expired when poked and + // given a moment + 1, + ); + + vcs.add_redeemable( + AuthCode([0; 24]), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + u64::MAX, + ); + + // Give a short time for the expiry to take place. + tokio::time::sleep(Duration::from_millis(1)).await; + + assert_matches!( + vcs.redeem(&VALID_CODE, [1; 32], [2; 32]), + CodeRedemption::Invalid + ); + } +}