Tests for the full flow up to and including the token endpoint

Signed-off-by: Olivier 'reivilibre <olivier@librepush.net>
This commit is contained in:
Olivier 'reivilibre' 2024-07-07 09:56:44 +01:00
parent f2b0a64fb0
commit 13e6cd5361
13 changed files with 281 additions and 106 deletions

View File

@ -110,7 +110,7 @@ pub struct UserInfo {
/// A wrapper around a database transaction with some database methods on it.
pub struct IdCoopStoreTxn<'a, 'txn> {
txn: &'a mut Transaction<'txn, Postgres>,
pub(crate) txn: &'a mut Transaction<'txn, Postgres>,
}
impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
.await
.context("failed to lookup login session")?;
let Some(row) = row_opt else { return Ok(None); };
let Some(row) = row_opt else {
return Ok(None);
};
Ok(Some(LoginSession {
user_name: row.user_name,
@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
.await
.context("failed to lookup application session")?;
let Some(row) = row_opt else { return Ok(None); };
let Some(row) = row_opt else {
return Ok(None);
};
Ok(Some(ApplicationSession {
application_session_id: row.session_id,

View File

@ -3,6 +3,7 @@ use std::sync::Arc;
use axum::Router;
use confique::{Config, Partial};
use josekit::jwk::alg::rsa::RsaKeyPair;
use metrics::atomics::AtomicU64;
use pgtemp::PgTempDB;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256StarStar;
@ -20,6 +21,7 @@ struct TestSystem {
web: Router,
config: Arc<Configuration>,
store: Arc<IdCoopStore>,
clock: Clock,
}
const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem");
@ -85,13 +87,13 @@ async fn basic_system() -> TestSystem {
let config = Arc::new(config);
let store = Arc::new(store);
let clock = Clock::Fake();
let clock = Clock(Arc::new(AtomicU64::new(0)));
let randgen = RandGen(Xoshiro256StarStar::seed_from_u64(424242));
let router = make_router(
store.clone(),
config.clone(),
Arc::new(secrets),
clock,
clock.clone(),
randgen,
)
.await
@ -102,6 +104,7 @@ async fn basic_system() -> TestSystem {
web: router,
config,
store,
clock,
}
}

View File

@ -4,7 +4,7 @@ expression: "(headers, text)"
---
- content-length: "55"
content-type: text/plain; charset=utf-8
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey"
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey"
set-cookie: __Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
x-frame-options: DENY
- Logged in. Redirecting you back to what you were doing.

View File

@ -0,0 +1,8 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "288"
content-type: text/html; charset=utf-8
x-frame-options: DENY
- "hi <u>robert</u>, consent to <u>AClient</u>? <form method='POST'><input type='hidden' name='xsrf' value='0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>"

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "46"
content-type: text/plain; charset=utf-8
location: "http://aclient.example.com/redirect?code=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO&state=wombat&iss=http%3A%2F%2Fissuer.example.com"
x-frame-options: DENY
- Authorisation succeeded; redirecting you back.

View File

@ -0,0 +1,14 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "803"
content-type: application/json
- access_token: HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE
expires_in: 31536000
id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg
refresh_token: bn0vvGrpCiQDMaH7MRw7FgHuPqoEYel0IFGYv_Z6E7o
scope: openid
token_type: Bearer

View File

@ -7,6 +7,7 @@ use axum_test_helper::{TestClient, TestResponse};
use insta::assert_yaml_snapshot;
use maplit::btreemap;
use sqlx::types::Uuid;
use crate::{passwords::create_password_hash, store::CreateUser, tests::basic_system};
@ -21,10 +22,15 @@ async fn dump_resp_text(
.headers()
.clone()
.into_iter()
.map(|(k, v)| (k.unwrap().to_string(), v.to_str().unwrap().to_owned()))
// skip headers that are duplicate keys.
// We don't care about theme for our purposes.
.filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned())))
.collect();
// Remove date because it's not stable across tests!
headers.remove("date");
// Remove vary because it has multiple values and we don't want to
// introduce instability into our tests by only allowing one through.
headers.remove("vary");
let text = resp.text().await;
eprintln!("=== Response for {req_name} ===");
eprintln!("Status: {status:?}");
@ -36,21 +42,20 @@ async fn dump_resp_text(
/// Tests the full flow...
#[tokio::test]
async fn test_todo() {
async fn test_full_flow() {
let sys = basic_system().await;
let uuid = Uuid::nil();
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
txn.create_user(CreateUser {
user_login_name: "robert".to_owned(),
password_hash: Some(pwhash),
locked: false,
})
.await
.unwrap();
sqlx::query(
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
).bind("robert").bind(uuid).bind(pwhash).bind(false)
.fetch_one(&mut **txn.txn)
.await.unwrap();
Ok(())
})
})
@ -59,22 +64,28 @@ async fn test_todo() {
let client = TestClient::new(sys.web);
///// These requests are on behalf of the user's browser /////
const CODE_VERIFIER: &str = "verifying";
// base64(sha256("verifying"))
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
// 1. /auth request
const LOGIN_URL: &str = "/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3Dchallenging%26code_challenge_method%3DS256%26nonce%3Dnoncey";
let resp = client.get("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey").send().await;
let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey");
let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await;
let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await;
assert_eq!(status, 302);
assert_eq!(headers.get("location").unwrap(), LOGIN_URL);
assert_eq!(headers.get("location").unwrap(), &login_url);
// 2. /login request
let resp = client.get(LOGIN_URL).send().await;
let resp = client.get(&login_url).send().await;
let (status, headers, text) = dump_resp_text("2. /login request", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!("2/login", (headers, text));
// 3. /login request with credentials
let resp = client
.post(LOGIN_URL)
.post(&login_url)
.form(&btreemap! {
"username" => "robert",
"password" => "secret",
@ -87,5 +98,58 @@ async fn test_todo() {
.await;
let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await;
assert_eq!(status, 302);
let auth_loc = headers.get("location").unwrap().to_owned();
assert_yaml_snapshot!("3/login", (headers, text));
// 4. /auth request
let resp = client
.get(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
)
.send()
.await;
let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!("4/auth", (headers, text));
sys.clock.set_time(30);
// 5. /auth request with confirmation
let resp = client
.post(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
)
.form(&btreemap! {
"action" => "accept",
"xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await;
assert_eq!(status, 302);
// Note this snapshot includes a Location: header back to the client application
assert_yaml_snapshot!("5/auth", (headers, text));
///// At this point, we make requests on behalf of the client /////
// 6. /token request
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO",
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("6. POST /token", resp).await;
assert_eq!(status, 200);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("6/token", (headers, json));
}

View File

@ -1,7 +1,16 @@
//! Miscellaneous utilities
#[cfg(not(test))]
pub use self::real_utils::{Clock, RandGen};
#[cfg(test)]
pub use self::test_utils::{Clock, RandGen};
#[cfg(not(test))]
mod real_utils {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use chrono::{DateTime, Utc};
use rand::{thread_rng, RngCore};
/// A source of random numbers that can be faked for tests.
@ -25,12 +34,47 @@ mod real_utils {
thread_rng().try_fill_bytes(dest)
}
}
/// A source of time that can be faked for tests.
#[derive(Clone)]
pub struct Clock;
impl Clock {
/// Returns the current time as a `DateTime<Utc>`.
pub fn now_utc(&self) -> DateTime<Utc> {
Utc::now()
}
/// Returns the current time as a i64 of seconds since the Unix epoch.
pub fn now_timestamp(&self) -> i64 {
Utc::now().timestamp()
}
/// Sleep until the given timestamp.
pub async fn sleep_until(&self, until_ts: i64) {
if until_ts < 0 {
return;
}
let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64);
let now = SystemTime::now();
let sleep_for = sleep_until
.duration_since(now)
.unwrap_or(Duration::from_secs(0));
tokio::time::sleep(sleep_for).await
}
}
}
#[cfg(test)]
mod test_utils {
use std::ops::{Deref, DerefMut};
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use chrono::{DateTime, TimeZone, Utc};
use rand_xoshiro::Xoshiro256StarStar;
#[derive(Clone)]
@ -48,19 +92,37 @@ mod test_utils {
&mut self.0
}
}
}
#[cfg(not(test))]
pub use self::real_utils::RandGen;
#[cfg(test)]
pub use self::test_utils::RandGen;
/// A source of time that can be faked for tests.
#[derive(Clone)]
pub enum Clock {
/// Use real time
Real,
/// Fake time for use in tests
Fake(),
#[derive(Clone)]
pub struct Clock(pub Arc<AtomicU64>);
impl Clock {
pub fn new_test() -> Self {
Clock(Arc::new(AtomicU64::new(0)))
}
pub fn set_time(&self, new: u64) {
self.0.store(new, std::sync::atomic::Ordering::Relaxed);
}
pub fn now_utc(&self) -> DateTime<Utc> {
Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0)
.earliest()
.unwrap()
}
pub fn now_timestamp(&self) -> i64 {
self.0.load(std::sync::atomic::Ordering::Relaxed) as i64
}
pub async fn sleep_until(&self, until_ts: i64) {
if until_ts < 0 {
return;
}
// Wait for time to advance past the requested timestamp
// TODO write a better test sleep implementation
while self.now_timestamp() < until_ts {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
}
}

View File

@ -125,7 +125,7 @@ pub(crate) async fn make_router(
.layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1))))
.layer(client_ip_source.into_extension())
.layer(Extension(Arc::new(ratelimiters)))
.layer(Extension(VolatileCodeStore::default()))
.layer(Extension(VolatileCodeStore::new(clock.clone())))
.layer(Extension(clock))
.layer(Extension(randgen));
@ -144,7 +144,7 @@ pub async fn serve(
use eyre::Context;
use tracing::info;
let router = make_router(store, config, secrets, Clock::Real, RandGen).await?;
let router = make_router(store, config, secrets, Clock, RandGen).await?;
info!("Listening on {bind:?}");
axum::Server::try_bind(&bind)

View File

@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, TimeZone, Utc};
use eyre::eyre;
use eyre::{bail, Context, ContextCompat};
use governor::Jitter;
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::Deserialize;
use sqlx::types::Uuid;
use tokio::sync::Semaphore;

View File

@ -9,7 +9,6 @@ use axum::{
Extension, Form,
};
use chrono::Utc;
use eyre::{Context, ContextCompat};
use serde::{Deserialize, Serialize};
@ -17,6 +16,7 @@ use tracing::{error, warn};
use crate::{
config::{Configuration, OidcClientConfiguration},
utils::{Clock, RandGen},
web::{
login::LoginSession,
make_login_redirect,
@ -69,6 +69,8 @@ pub async fn oidc_authorisation(
login_session: Option<LoginSession>,
Extension(config): Extension<Arc<Configuration>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(clock): Extension<Clock>,
Extension(mut randgen): Extension<RandGen>,
OriginalUri(uri): OriginalUri,
) -> Response {
let Query(query) = match query {
@ -109,7 +111,7 @@ pub async fn oidc_authorisation(
// If the application requires consent, then we should ask for that.
if !client_config.skip_consent {
return show_consent_page(login_session, client_config, &config).await;
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
}
// No consent needed: process the authorisation.
@ -120,6 +122,8 @@ pub async fn oidc_authorisation(
client_id,
client_config,
&config,
&mut randgen,
&clock,
&code_store,
)
.await
@ -133,11 +137,14 @@ pub struct PostConsentForm {
}
/// `POST /oidc/auth`
#[allow(clippy::too_many_arguments)]
pub async fn post_oidc_authorisation_consent(
Query(query): Query<AuthorisationQuery>,
login_session: Option<LoginSession>,
Extension(config): Extension<Arc<Configuration>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(clock): Extension<Clock>,
Extension(mut randgen): Extension<RandGen>,
OriginalUri(uri): OriginalUri,
Form(form): Form<PostConsentForm>,
) -> Response {
@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent(
};
if login_session
.validate_xsrf_token(&form.xsrf, Utc::now())
.validate_xsrf_token(&form.xsrf, clock.now_utc())
.is_err()
{
// XSRF token is not valid, so show the consent form again...
return show_consent_page(login_session, client_config, &config).await;
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
}
match form.action.as_str() {
@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent(
client_id,
client_config,
&config,
&mut randgen,
&clock,
&code_store,
)
.await
@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>(
async fn show_consent_page(
login_session: LoginSession,
client_config: &OidcClientConfiguration,
Extension(clock): Extension<Clock>,
_config: &Configuration,
) -> Response {
let xsrf_token = login_session
.generate_xsrf_token(Utc::now())
.generate_xsrf_token(clock.now_utc())
.expect("must be able to create a XSRF token");
Html(format!(
"hi <u>{}</u>, consent to <u>{}</u>? <form method='POST'><input type='hidden' name='xsrf' value='{}'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>",
@ -253,12 +263,15 @@ async fn show_consent_page(
/// Preconditions:
/// - any required consent from the user has now been obtained
/// - query.request_uri has been validated as a safe redirect URI
#[allow(clippy::too_many_arguments)]
async fn process_authorisation(
query: AuthorisationQuery,
login_session: LoginSession,
client_id: String,
_client_config: &OidcClientConfiguration,
config: &Configuration,
randgen: &mut RandGen,
clock: &Clock,
code_store: &VolatileCodeStore,
) -> Response {
assert_eq!(
@ -271,7 +284,7 @@ async fn process_authorisation(
// Generate a 192-bit random code, which fits into exactly 32 base64 characters.
// This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random.
let code = AuthCode::generate_new_random();
let code = AuthCode::generate_new_random(randgen);
let code_base64url = code.to_string();
// Write down the code and other details in-memory with 10 minute expiry...
@ -289,7 +302,7 @@ async fn process_authorisation(
user_id: login_session.user_id,
user_login_session_id: login_session.login_session_id,
},
0,
clock.now_timestamp() + 600,
);
#[derive(Serialize)]

View File

@ -6,7 +6,6 @@ use std::{
fmt::Display,
str::FromStr,
sync::{Arc, Mutex},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine};
@ -14,6 +13,8 @@ use rand::Rng;
use sqlx::types::Uuid;
use tokio::sync::Notify;
use crate::utils::{Clock, RandGen};
/// Display shows the auth code as base64 (URL-safe non-padded).
/// FromStr/parse parses the same format.
#[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
@ -61,8 +62,8 @@ impl FromStr for AuthCode {
impl AuthCode {
/// Generate a new authorisation code using the thread's RNG
pub fn generate_new_random() -> Self {
Self(rand::thread_rng().gen::<[u8; 24]>())
pub fn generate_new_random(randgen: &mut RandGen) -> Self {
Self(randgen.gen::<[u8; 24]>())
}
}
@ -112,7 +113,7 @@ struct VolatileCodeStoreInner {
pub conflictable_codes: HashMap<AuthCode, RedeemedAuthCode>,
/// Time when codes will expire
pub expire_codes_at: BTreeSet<(u64, AuthCode)>,
pub expire_codes_at: BTreeSet<(i64, AuthCode)>,
}
impl VolatileCodeStoreInner {
@ -148,7 +149,7 @@ impl VolatileCodeStoreInner {
&mut self,
auth_code: AuthCode,
auth_code_binding: AuthCodeBinding,
expires_at: u64,
expires_at: i64,
) {
self.redeemable_codes
.insert(auth_code.clone(), auth_code_binding);
@ -156,7 +157,7 @@ impl VolatileCodeStoreInner {
}
/// Removes all expired auth codes and returns the time of the earliest next expiry, if present.
pub(self) fn handle_expiry(&mut self, now: u64) -> Option<u64> {
pub(self) fn handle_expiry(&mut self, now: i64) -> Option<i64> {
loop {
let (ts, _auth_code) = self.expire_codes_at.first()?;
@ -182,15 +183,16 @@ pub struct VolatileCodeStore {
inner: Arc<Mutex<VolatileCodeStoreInner>>,
}
impl Default for VolatileCodeStore {
fn default() -> Self {
impl VolatileCodeStore {
/// Create a new instance.
pub fn new(clock: Clock) -> Self {
let poke = Arc::new(Notify::new());
let inner: Arc<Mutex<VolatileCodeStoreInner>> = Default::default();
{
let poke = poke.clone();
let inner = inner.clone();
tokio::spawn(Self::expirer(inner, poke));
tokio::spawn(Self::expirer(inner, poke, clock));
}
VolatileCodeStore { inner, poke }
@ -198,19 +200,15 @@ impl Default for VolatileCodeStore {
}
impl VolatileCodeStore {
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>) {
let mut next_expiry: Option<u64> = None;
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>, clock: Clock) {
let mut next_expiry: Option<i64> = None;
loop {
match next_expiry {
Some(next_expiry) => {
let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry);
let now = SystemTime::now();
let sleep_for = sleep_until
.duration_since(now)
.unwrap_or(Duration::from_secs(60));
let sleep_future = clock.sleep_until(next_expiry);
tokio::select! {
_ = poke.notified() => {},
_ = tokio::time::sleep(sleep_for) => {},
_ = sleep_future => {},
}
}
None => {
@ -218,14 +216,9 @@ impl VolatileCodeStore {
}
}
let now = SystemTime::now();
next_expiry = {
let mut inner = inner.lock().unwrap();
inner.handle_expiry(
now.duration_since(UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_secs(),
)
inner.handle_expiry(clock.now_timestamp())
};
}
}
@ -244,7 +237,7 @@ impl VolatileCodeStore {
}
/// Add a new redeemable authorisation code.
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) {
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) {
let mut inner = self.inner.lock().unwrap();
inner.add_redeemable(auth_code, binding, expires_at);
drop(inner);
@ -284,7 +277,7 @@ mod test {
use rstest::{fixture, rstest};
use sqlx::types::Uuid;
use crate::web::oauth_openid::ext_codes::CodeRedemption;
use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption};
use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner};
@ -422,7 +415,8 @@ mod test {
/// We can't easily cover everything here but may as well test the basics.
#[tokio::test]
async fn test_fullfat_store_basic() {
let vcs = VolatileCodeStore::default();
let clock = Clock::new_test();
let vcs = VolatileCodeStore::new(clock.clone());
vcs.add_redeemable(
VALID_CODE.clone(),
@ -435,7 +429,7 @@ mod test {
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
u64::MAX,
i64::MAX,
);
assert_matches!(
@ -458,6 +452,7 @@ mod test {
// given a moment
1,
);
clock.set_time(2);
vcs.add_redeemable(
AuthCode([0; 24]),
@ -470,11 +465,11 @@ mod test {
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
u64::MAX,
i64::MAX,
);
// Give a short time for the expiry to take place.
tokio::time::sleep(Duration::from_millis(1)).await;
tokio::time::sleep(Duration::from_millis(2)).await;
assert_matches!(
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),

View File

@ -1,10 +1,6 @@
//! `/oidc/token`
use std::{
str::FromStr,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use std::{str::FromStr, sync::Arc};
use axum::{
extract::rejection::FormRejection,
@ -15,13 +11,13 @@ use axum::{
};
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use blake2::Blake2s256;
use chrono::{Duration, Utc};
use chrono::Duration;
use eyre::{bail, Context};
use josekit::{
jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader},
jwt::JwtPayload,
};
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
@ -30,6 +26,7 @@ use tracing::{debug, error};
use crate::{
config::{Configuration, SecretConfig},
store::IdCoopStore,
utils::{Clock, RandGen},
};
use super::ext_codes::{
@ -53,12 +50,15 @@ pub struct TokenFormParams {
/// OpenID Connect clients call this to exchange an authorisation code they received for an access token.
///
/// TODO auth_header can be one alternative auth method
#[allow(clippy::too_many_arguments)]
pub async fn oidc_token(
basic_auth: Option<TypedHeader<Authorization<Basic>>>,
Extension(config): Extension<Arc<Configuration>>,
Extension(secrets): Extension<Arc<SecretConfig>>,
Extension(store): Extension<Arc<IdCoopStore>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(mut randgen): Extension<RandGen>,
Extension(clock): Extension<Clock>,
form: Result<Form<TokenFormParams>, FormRejection>,
) -> impl IntoResponse {
let form = match form {
@ -110,8 +110,9 @@ pub async fn oidc_token(
Json(TokenError {
code: TokenErrorCode::InvalidClient,
description: "That `client_id` is not recognised here.".to_string(),
})
).into_response();
}),
)
.into_response();
};
if !bool::from(
@ -181,10 +182,10 @@ pub async fn oidc_token(
// Create an access token but don't actually issue it yet:
// This lets us store the hash of the access token against the redemption of the auth code,
// so double redemptions can invalidate the access token appropriately.
let access_token = thread_rng().gen::<AccessToken>();
let access_token = randgen.gen::<AccessToken>();
let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token);
let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into();
let refresh_token = thread_rng().gen::<RefreshToken>();
let refresh_token = randgen.gen::<RefreshToken>();
let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token);
let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into();
@ -256,7 +257,9 @@ pub async fn oidc_token(
}
// 2. Check the code challenge
let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else {
let Some(computed_code_challenge) =
compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier)
else {
return (
StatusCode::BAD_REQUEST,
Json(TokenError {
@ -287,28 +290,29 @@ pub async fn oidc_token(
let user_login_session_id = binding.user_login_session_id;
// Issue access token for a new session
let clock2 = clock.clone();
match store
.txn(move |mut txn| {
Box::pin(async move {
let Some(session_id) = txn
.create_application_session(user_id, &application_id, user_login_session_id)
.await
.context("create_application_session")? else {
return Ok(Err(
(
StatusCode::BAD_REQUEST,
Json(TokenError {
code: TokenErrorCode::InvalidGrant,
description: "Auth code has expired or was not valid.".to_owned(),
}),
).into_response()
));
.context("create_application_session")?
else {
return Ok(Err((
StatusCode::BAD_REQUEST,
Json(TokenError {
code: TokenErrorCode::InvalidGrant,
description: "Auth code has expired or was not valid.".to_owned(),
}),
)
.into_response()));
};
txn.issue_access_token(
&access_token_hash,
session_id,
// TODO(expiry) Support custom expiry, not 100 years
Utc::now() + Duration::days(365 * 100),
clock2.now_utc() + Duration::days(365 * 100),
)
.await
.context("issue_access_token")?;
@ -316,7 +320,7 @@ pub async fn oidc_token(
&refresh_token_hash,
session_id,
// TODO(expiry) Support custom expiry, not 100 years
Utc::now() + Duration::days(365 * 100),
clock2.now_utc() + Duration::days(365 * 100),
)
.await
.context("issue_refresh_token")?;
@ -344,10 +348,7 @@ pub async fn oidc_token(
}
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Before unix epoch?")
.as_secs();
let now = clock.now_timestamp();
// TODO(expiry) Support custom expiry times (not just 100 years)
let exp = now + 100 * 365 * 86400;
let sub = binding.user_id.hyphenated().to_string();
@ -385,7 +386,9 @@ pub async fn oidc_token(
}
fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result<String> {
let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else {
let Ok(serde_json::Value::Object(map)) =
serde_json::to_value(id_token).context("failed to serialise ID Token content")
else {
bail!("ID Token not a map");
};
@ -447,9 +450,9 @@ pub struct IdToken {
/// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew.
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
/// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular.
pub exp: u64,
pub exp: i64,
/// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
pub iat: u64,
pub iat: i64,
/// Time when the End-User authentication occurred.
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
/// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.