Tests for the full flow up to and including the token endpoint
Signed-off-by: Olivier 'reivilibre <olivier@librepush.net>
This commit is contained in:
parent
f2b0a64fb0
commit
13e6cd5361
10
src/store.rs
10
src/store.rs
@ -110,7 +110,7 @@ pub struct UserInfo {
|
||||
|
||||
/// A wrapper around a database transaction with some database methods on it.
|
||||
pub struct IdCoopStoreTxn<'a, 'txn> {
|
||||
txn: &'a mut Transaction<'txn, Postgres>,
|
||||
pub(crate) txn: &'a mut Transaction<'txn, Postgres>,
|
||||
}
|
||||
|
||||
impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
||||
@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
||||
.await
|
||||
.context("failed to lookup login session")?;
|
||||
|
||||
let Some(row) = row_opt else { return Ok(None); };
|
||||
let Some(row) = row_opt else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(LoginSession {
|
||||
user_name: row.user_name,
|
||||
@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
||||
.await
|
||||
.context("failed to lookup application session")?;
|
||||
|
||||
let Some(row) = row_opt else { return Ok(None); };
|
||||
let Some(row) = row_opt else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Ok(Some(ApplicationSession {
|
||||
application_session_id: row.session_id,
|
||||
|
||||
@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use confique::{Config, Partial};
|
||||
use josekit::jwk::alg::rsa::RsaKeyPair;
|
||||
use metrics::atomics::AtomicU64;
|
||||
use pgtemp::PgTempDB;
|
||||
use rand::SeedableRng;
|
||||
use rand_xoshiro::Xoshiro256StarStar;
|
||||
@ -20,6 +21,7 @@ struct TestSystem {
|
||||
web: Router,
|
||||
config: Arc<Configuration>,
|
||||
store: Arc<IdCoopStore>,
|
||||
clock: Clock,
|
||||
}
|
||||
|
||||
const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem");
|
||||
@ -85,13 +87,13 @@ async fn basic_system() -> TestSystem {
|
||||
|
||||
let config = Arc::new(config);
|
||||
let store = Arc::new(store);
|
||||
let clock = Clock::Fake();
|
||||
let clock = Clock(Arc::new(AtomicU64::new(0)));
|
||||
let randgen = RandGen(Xoshiro256StarStar::seed_from_u64(424242));
|
||||
let router = make_router(
|
||||
store.clone(),
|
||||
config.clone(),
|
||||
Arc::new(secrets),
|
||||
clock,
|
||||
clock.clone(),
|
||||
randgen,
|
||||
)
|
||||
.await
|
||||
@ -102,6 +104,7 @@ async fn basic_system() -> TestSystem {
|
||||
web: router,
|
||||
config,
|
||||
store,
|
||||
clock,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ expression: "(headers, text)"
|
||||
---
|
||||
- content-length: "55"
|
||||
content-type: text/plain; charset=utf-8
|
||||
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey"
|
||||
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey"
|
||||
set-cookie: __Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
|
||||
x-frame-options: DENY
|
||||
- Logged in. Redirecting you back to what you were doing.
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
---
|
||||
source: src/tests/test_oidc_auth_flow.rs
|
||||
expression: "(headers, text)"
|
||||
---
|
||||
- content-length: "288"
|
||||
content-type: text/html; charset=utf-8
|
||||
x-frame-options: DENY
|
||||
- "hi <u>robert</u>, consent to <u>AClient</u>? <form method='POST'><input type='hidden' name='xsrf' value='0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>"
|
||||
@ -0,0 +1,9 @@
|
||||
---
|
||||
source: src/tests/test_oidc_auth_flow.rs
|
||||
expression: "(headers, text)"
|
||||
---
|
||||
- content-length: "46"
|
||||
content-type: text/plain; charset=utf-8
|
||||
location: "http://aclient.example.com/redirect?code=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO&state=wombat&iss=http%3A%2F%2Fissuer.example.com"
|
||||
x-frame-options: DENY
|
||||
- Authorisation succeeded; redirecting you back.
|
||||
@ -0,0 +1,14 @@
|
||||
---
|
||||
source: src/tests/test_oidc_auth_flow.rs
|
||||
expression: "(headers, json)"
|
||||
---
|
||||
- access-control-allow-origin: "*"
|
||||
access-control-expose-headers: "*"
|
||||
content-length: "803"
|
||||
content-type: application/json
|
||||
- access_token: HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE
|
||||
expires_in: 31536000
|
||||
id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg
|
||||
refresh_token: bn0vvGrpCiQDMaH7MRw7FgHuPqoEYel0IFGYv_Z6E7o
|
||||
scope: openid
|
||||
token_type: Bearer
|
||||
@ -7,6 +7,7 @@ use axum_test_helper::{TestClient, TestResponse};
|
||||
use insta::assert_yaml_snapshot;
|
||||
|
||||
use maplit::btreemap;
|
||||
use sqlx::types::Uuid;
|
||||
|
||||
use crate::{passwords::create_password_hash, store::CreateUser, tests::basic_system};
|
||||
|
||||
@ -21,10 +22,15 @@ async fn dump_resp_text(
|
||||
.headers()
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.unwrap().to_string(), v.to_str().unwrap().to_owned()))
|
||||
// skip headers that are duplicate keys.
|
||||
// We don't care about theme for our purposes.
|
||||
.filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned())))
|
||||
.collect();
|
||||
// Remove date because it's not stable across tests!
|
||||
headers.remove("date");
|
||||
// Remove vary because it has multiple values and we don't want to
|
||||
// introduce instability into our tests by only allowing one through.
|
||||
headers.remove("vary");
|
||||
let text = resp.text().await;
|
||||
eprintln!("=== Response for {req_name} ===");
|
||||
eprintln!("Status: {status:?}");
|
||||
@ -36,21 +42,20 @@ async fn dump_resp_text(
|
||||
|
||||
/// Tests the full flow...
|
||||
#[tokio::test]
|
||||
async fn test_todo() {
|
||||
async fn test_full_flow() {
|
||||
let sys = basic_system().await;
|
||||
|
||||
let uuid = Uuid::nil();
|
||||
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
|
||||
let _: () = sys
|
||||
.store
|
||||
.txn(|mut txn| {
|
||||
Box::pin(async move {
|
||||
txn.create_user(CreateUser {
|
||||
user_login_name: "robert".to_owned(),
|
||||
password_hash: Some(pwhash),
|
||||
locked: false,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
sqlx::query(
|
||||
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
|
||||
).bind("robert").bind(uuid).bind(pwhash).bind(false)
|
||||
.fetch_one(&mut **txn.txn)
|
||||
.await.unwrap();
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
@ -59,22 +64,28 @@ async fn test_todo() {
|
||||
|
||||
let client = TestClient::new(sys.web);
|
||||
|
||||
///// These requests are on behalf of the user's browser /////
|
||||
|
||||
const CODE_VERIFIER: &str = "verifying";
|
||||
// base64(sha256("verifying"))
|
||||
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
|
||||
|
||||
// 1. /auth request
|
||||
const LOGIN_URL: &str = "/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3Dchallenging%26code_challenge_method%3DS256%26nonce%3Dnoncey";
|
||||
let resp = client.get("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey").send().await;
|
||||
let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey");
|
||||
let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await;
|
||||
let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await;
|
||||
assert_eq!(status, 302);
|
||||
assert_eq!(headers.get("location").unwrap(), LOGIN_URL);
|
||||
assert_eq!(headers.get("location").unwrap(), &login_url);
|
||||
|
||||
// 2. /login request
|
||||
let resp = client.get(LOGIN_URL).send().await;
|
||||
let resp = client.get(&login_url).send().await;
|
||||
let (status, headers, text) = dump_resp_text("2. /login request", resp).await;
|
||||
assert_eq!(status, 200);
|
||||
assert_yaml_snapshot!("2/login", (headers, text));
|
||||
|
||||
// 3. /login request with credentials
|
||||
let resp = client
|
||||
.post(LOGIN_URL)
|
||||
.post(&login_url)
|
||||
.form(&btreemap! {
|
||||
"username" => "robert",
|
||||
"password" => "secret",
|
||||
@ -87,5 +98,58 @@ async fn test_todo() {
|
||||
.await;
|
||||
let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await;
|
||||
assert_eq!(status, 302);
|
||||
let auth_loc = headers.get("location").unwrap().to_owned();
|
||||
assert_yaml_snapshot!("3/login", (headers, text));
|
||||
|
||||
// 4. /auth request
|
||||
let resp = client
|
||||
.get(&auth_loc)
|
||||
.header(
|
||||
"Cookie",
|
||||
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await;
|
||||
assert_eq!(status, 200);
|
||||
assert_yaml_snapshot!("4/auth", (headers, text));
|
||||
|
||||
sys.clock.set_time(30);
|
||||
|
||||
// 5. /auth request with confirmation
|
||||
let resp = client
|
||||
.post(&auth_loc)
|
||||
.header(
|
||||
"Cookie",
|
||||
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
|
||||
)
|
||||
.form(&btreemap! {
|
||||
"action" => "accept",
|
||||
"xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y",
|
||||
})
|
||||
.send()
|
||||
.await;
|
||||
let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await;
|
||||
assert_eq!(status, 302);
|
||||
// Note this snapshot includes a Location: header back to the client application
|
||||
assert_yaml_snapshot!("5/auth", (headers, text));
|
||||
|
||||
///// At this point, we make requests on behalf of the client /////
|
||||
|
||||
// 6. /token request
|
||||
let resp = client
|
||||
.post("/oidc/token")
|
||||
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
|
||||
.form(&btreemap! {
|
||||
"code" => "HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO",
|
||||
"code_verifier" => CODE_VERIFIER,
|
||||
"grant_type" => "authorization_code",
|
||||
"redirect_uri" => "http://aclient.example.com/redirect",
|
||||
})
|
||||
.send()
|
||||
.await;
|
||||
let (status, headers, text) = dump_resp_text("6. POST /token", resp).await;
|
||||
assert_eq!(status, 200);
|
||||
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
|
||||
assert_yaml_snapshot!("6/token", (headers, json));
|
||||
}
|
||||
|
||||
94
src/utils.rs
94
src/utils.rs
@ -1,7 +1,16 @@
|
||||
//! Miscellaneous utilities
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub use self::real_utils::{Clock, RandGen};
|
||||
|
||||
#[cfg(test)]
|
||||
pub use self::test_utils::{Clock, RandGen};
|
||||
|
||||
#[cfg(not(test))]
|
||||
mod real_utils {
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
/// A source of random numbers that can be faked for tests.
|
||||
@ -25,12 +34,47 @@ mod real_utils {
|
||||
thread_rng().try_fill_bytes(dest)
|
||||
}
|
||||
}
|
||||
|
||||
/// A source of time that can be faked for tests.
|
||||
#[derive(Clone)]
|
||||
pub struct Clock;
|
||||
|
||||
impl Clock {
|
||||
/// Returns the current time as a `DateTime<Utc>`.
|
||||
pub fn now_utc(&self) -> DateTime<Utc> {
|
||||
Utc::now()
|
||||
}
|
||||
|
||||
/// Returns the current time as a i64 of seconds since the Unix epoch.
|
||||
pub fn now_timestamp(&self) -> i64 {
|
||||
Utc::now().timestamp()
|
||||
}
|
||||
|
||||
/// Sleep until the given timestamp.
|
||||
pub async fn sleep_until(&self, until_ts: i64) {
|
||||
if until_ts < 0 {
|
||||
return;
|
||||
}
|
||||
let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64);
|
||||
let now = SystemTime::now();
|
||||
let sleep_for = sleep_until
|
||||
.duration_since(now)
|
||||
.unwrap_or(Duration::from_secs(0));
|
||||
tokio::time::sleep(sleep_for).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_utils {
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use rand_xoshiro::Xoshiro256StarStar;
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -48,19 +92,37 @@ mod test_utils {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub use self::real_utils::RandGen;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use self::test_utils::RandGen;
|
||||
|
||||
/// A source of time that can be faked for tests.
|
||||
#[derive(Clone)]
|
||||
pub enum Clock {
|
||||
/// Use real time
|
||||
Real,
|
||||
/// Fake time for use in tests
|
||||
Fake(),
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Clock(pub Arc<AtomicU64>);
|
||||
|
||||
impl Clock {
|
||||
pub fn new_test() -> Self {
|
||||
Clock(Arc::new(AtomicU64::new(0)))
|
||||
}
|
||||
pub fn set_time(&self, new: u64) {
|
||||
self.0.store(new, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn now_utc(&self) -> DateTime<Utc> {
|
||||
Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0)
|
||||
.earliest()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn now_timestamp(&self) -> i64 {
|
||||
self.0.load(std::sync::atomic::Ordering::Relaxed) as i64
|
||||
}
|
||||
|
||||
pub async fn sleep_until(&self, until_ts: i64) {
|
||||
if until_ts < 0 {
|
||||
return;
|
||||
}
|
||||
// Wait for time to advance past the requested timestamp
|
||||
// TODO write a better test sleep implementation
|
||||
while self.now_timestamp() < until_ts {
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,7 +125,7 @@ pub(crate) async fn make_router(
|
||||
.layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1))))
|
||||
.layer(client_ip_source.into_extension())
|
||||
.layer(Extension(Arc::new(ratelimiters)))
|
||||
.layer(Extension(VolatileCodeStore::default()))
|
||||
.layer(Extension(VolatileCodeStore::new(clock.clone())))
|
||||
.layer(Extension(clock))
|
||||
.layer(Extension(randgen));
|
||||
|
||||
@ -144,7 +144,7 @@ pub async fn serve(
|
||||
use eyre::Context;
|
||||
use tracing::info;
|
||||
|
||||
let router = make_router(store, config, secrets, Clock::Real, RandGen).await?;
|
||||
let router = make_router(store, config, secrets, Clock, RandGen).await?;
|
||||
|
||||
info!("Listening on {bind:?}");
|
||||
axum::Server::try_bind(&bind)
|
||||
|
||||
@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, TimeZone, Utc};
|
||||
use eyre::eyre;
|
||||
use eyre::{bail, Context, ContextCompat};
|
||||
use governor::Jitter;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand::Rng;
|
||||
use serde::Deserialize;
|
||||
use sqlx::types::Uuid;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
@ -9,7 +9,6 @@ use axum::{
|
||||
Extension, Form,
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use eyre::{Context, ContextCompat};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -17,6 +16,7 @@ use tracing::{error, warn};
|
||||
|
||||
use crate::{
|
||||
config::{Configuration, OidcClientConfiguration},
|
||||
utils::{Clock, RandGen},
|
||||
web::{
|
||||
login::LoginSession,
|
||||
make_login_redirect,
|
||||
@ -69,6 +69,8 @@ pub async fn oidc_authorisation(
|
||||
login_session: Option<LoginSession>,
|
||||
Extension(config): Extension<Arc<Configuration>>,
|
||||
Extension(code_store): Extension<VolatileCodeStore>,
|
||||
Extension(clock): Extension<Clock>,
|
||||
Extension(mut randgen): Extension<RandGen>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
) -> Response {
|
||||
let Query(query) = match query {
|
||||
@ -109,7 +111,7 @@ pub async fn oidc_authorisation(
|
||||
|
||||
// If the application requires consent, then we should ask for that.
|
||||
if !client_config.skip_consent {
|
||||
return show_consent_page(login_session, client_config, &config).await;
|
||||
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
|
||||
}
|
||||
|
||||
// No consent needed: process the authorisation.
|
||||
@ -120,6 +122,8 @@ pub async fn oidc_authorisation(
|
||||
client_id,
|
||||
client_config,
|
||||
&config,
|
||||
&mut randgen,
|
||||
&clock,
|
||||
&code_store,
|
||||
)
|
||||
.await
|
||||
@ -133,11 +137,14 @@ pub struct PostConsentForm {
|
||||
}
|
||||
|
||||
/// `POST /oidc/auth`
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn post_oidc_authorisation_consent(
|
||||
Query(query): Query<AuthorisationQuery>,
|
||||
login_session: Option<LoginSession>,
|
||||
Extension(config): Extension<Arc<Configuration>>,
|
||||
Extension(code_store): Extension<VolatileCodeStore>,
|
||||
Extension(clock): Extension<Clock>,
|
||||
Extension(mut randgen): Extension<RandGen>,
|
||||
OriginalUri(uri): OriginalUri,
|
||||
Form(form): Form<PostConsentForm>,
|
||||
) -> Response {
|
||||
@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent(
|
||||
};
|
||||
|
||||
if login_session
|
||||
.validate_xsrf_token(&form.xsrf, Utc::now())
|
||||
.validate_xsrf_token(&form.xsrf, clock.now_utc())
|
||||
.is_err()
|
||||
{
|
||||
// XSRF token is not valid, so show the consent form again...
|
||||
return show_consent_page(login_session, client_config, &config).await;
|
||||
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
|
||||
}
|
||||
|
||||
match form.action.as_str() {
|
||||
@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent(
|
||||
client_id,
|
||||
client_config,
|
||||
&config,
|
||||
&mut randgen,
|
||||
&clock,
|
||||
&code_store,
|
||||
)
|
||||
.await
|
||||
@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>(
|
||||
async fn show_consent_page(
|
||||
login_session: LoginSession,
|
||||
client_config: &OidcClientConfiguration,
|
||||
Extension(clock): Extension<Clock>,
|
||||
_config: &Configuration,
|
||||
) -> Response {
|
||||
let xsrf_token = login_session
|
||||
.generate_xsrf_token(Utc::now())
|
||||
.generate_xsrf_token(clock.now_utc())
|
||||
.expect("must be able to create a XSRF token");
|
||||
Html(format!(
|
||||
"hi <u>{}</u>, consent to <u>{}</u>? <form method='POST'><input type='hidden' name='xsrf' value='{}'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>",
|
||||
@ -253,12 +263,15 @@ async fn show_consent_page(
|
||||
/// Preconditions:
|
||||
/// - any required consent from the user has now been obtained
|
||||
/// - query.request_uri has been validated as a safe redirect URI
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn process_authorisation(
|
||||
query: AuthorisationQuery,
|
||||
login_session: LoginSession,
|
||||
client_id: String,
|
||||
_client_config: &OidcClientConfiguration,
|
||||
config: &Configuration,
|
||||
randgen: &mut RandGen,
|
||||
clock: &Clock,
|
||||
code_store: &VolatileCodeStore,
|
||||
) -> Response {
|
||||
assert_eq!(
|
||||
@ -271,7 +284,7 @@ async fn process_authorisation(
|
||||
|
||||
// Generate a 192-bit random code, which fits into exactly 32 base64 characters.
|
||||
// This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random.
|
||||
let code = AuthCode::generate_new_random();
|
||||
let code = AuthCode::generate_new_random(randgen);
|
||||
let code_base64url = code.to_string();
|
||||
|
||||
// Write down the code and other details in-memory with 10 minute expiry...
|
||||
@ -289,7 +302,7 @@ async fn process_authorisation(
|
||||
user_id: login_session.user_id,
|
||||
user_login_session_id: login_session.login_session_id,
|
||||
},
|
||||
0,
|
||||
clock.now_timestamp() + 600,
|
||||
);
|
||||
|
||||
#[derive(Serialize)]
|
||||
|
||||
@ -6,7 +6,6 @@ use std::{
|
||||
fmt::Display,
|
||||
str::FromStr,
|
||||
sync::{Arc, Mutex},
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
||||
@ -14,6 +13,8 @@ use rand::Rng;
|
||||
use sqlx::types::Uuid;
|
||||
use tokio::sync::Notify;
|
||||
|
||||
use crate::utils::{Clock, RandGen};
|
||||
|
||||
/// Display shows the auth code as base64 (URL-safe non-padded).
|
||||
/// FromStr/parse parses the same format.
|
||||
#[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
|
||||
@ -61,8 +62,8 @@ impl FromStr for AuthCode {
|
||||
|
||||
impl AuthCode {
|
||||
/// Generate a new authorisation code using the thread's RNG
|
||||
pub fn generate_new_random() -> Self {
|
||||
Self(rand::thread_rng().gen::<[u8; 24]>())
|
||||
pub fn generate_new_random(randgen: &mut RandGen) -> Self {
|
||||
Self(randgen.gen::<[u8; 24]>())
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,7 +113,7 @@ struct VolatileCodeStoreInner {
|
||||
pub conflictable_codes: HashMap<AuthCode, RedeemedAuthCode>,
|
||||
|
||||
/// Time when codes will expire
|
||||
pub expire_codes_at: BTreeSet<(u64, AuthCode)>,
|
||||
pub expire_codes_at: BTreeSet<(i64, AuthCode)>,
|
||||
}
|
||||
|
||||
impl VolatileCodeStoreInner {
|
||||
@ -148,7 +149,7 @@ impl VolatileCodeStoreInner {
|
||||
&mut self,
|
||||
auth_code: AuthCode,
|
||||
auth_code_binding: AuthCodeBinding,
|
||||
expires_at: u64,
|
||||
expires_at: i64,
|
||||
) {
|
||||
self.redeemable_codes
|
||||
.insert(auth_code.clone(), auth_code_binding);
|
||||
@ -156,7 +157,7 @@ impl VolatileCodeStoreInner {
|
||||
}
|
||||
|
||||
/// Removes all expired auth codes and returns the time of the earliest next expiry, if present.
|
||||
pub(self) fn handle_expiry(&mut self, now: u64) -> Option<u64> {
|
||||
pub(self) fn handle_expiry(&mut self, now: i64) -> Option<i64> {
|
||||
loop {
|
||||
let (ts, _auth_code) = self.expire_codes_at.first()?;
|
||||
|
||||
@ -182,15 +183,16 @@ pub struct VolatileCodeStore {
|
||||
inner: Arc<Mutex<VolatileCodeStoreInner>>,
|
||||
}
|
||||
|
||||
impl Default for VolatileCodeStore {
|
||||
fn default() -> Self {
|
||||
impl VolatileCodeStore {
|
||||
/// Create a new instance.
|
||||
pub fn new(clock: Clock) -> Self {
|
||||
let poke = Arc::new(Notify::new());
|
||||
let inner: Arc<Mutex<VolatileCodeStoreInner>> = Default::default();
|
||||
|
||||
{
|
||||
let poke = poke.clone();
|
||||
let inner = inner.clone();
|
||||
tokio::spawn(Self::expirer(inner, poke));
|
||||
tokio::spawn(Self::expirer(inner, poke, clock));
|
||||
}
|
||||
|
||||
VolatileCodeStore { inner, poke }
|
||||
@ -198,19 +200,15 @@ impl Default for VolatileCodeStore {
|
||||
}
|
||||
|
||||
impl VolatileCodeStore {
|
||||
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>) {
|
||||
let mut next_expiry: Option<u64> = None;
|
||||
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>, clock: Clock) {
|
||||
let mut next_expiry: Option<i64> = None;
|
||||
loop {
|
||||
match next_expiry {
|
||||
Some(next_expiry) => {
|
||||
let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry);
|
||||
let now = SystemTime::now();
|
||||
let sleep_for = sleep_until
|
||||
.duration_since(now)
|
||||
.unwrap_or(Duration::from_secs(60));
|
||||
let sleep_future = clock.sleep_until(next_expiry);
|
||||
tokio::select! {
|
||||
_ = poke.notified() => {},
|
||||
_ = tokio::time::sleep(sleep_for) => {},
|
||||
_ = sleep_future => {},
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@ -218,14 +216,9 @@ impl VolatileCodeStore {
|
||||
}
|
||||
}
|
||||
|
||||
let now = SystemTime::now();
|
||||
next_expiry = {
|
||||
let mut inner = inner.lock().unwrap();
|
||||
inner.handle_expiry(
|
||||
now.duration_since(UNIX_EPOCH)
|
||||
.expect("system clock before unix epoch")
|
||||
.as_secs(),
|
||||
)
|
||||
inner.handle_expiry(clock.now_timestamp())
|
||||
};
|
||||
}
|
||||
}
|
||||
@ -244,7 +237,7 @@ impl VolatileCodeStore {
|
||||
}
|
||||
|
||||
/// Add a new redeemable authorisation code.
|
||||
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) {
|
||||
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.add_redeemable(auth_code, binding, expires_at);
|
||||
drop(inner);
|
||||
@ -284,7 +277,7 @@ mod test {
|
||||
use rstest::{fixture, rstest};
|
||||
use sqlx::types::Uuid;
|
||||
|
||||
use crate::web::oauth_openid::ext_codes::CodeRedemption;
|
||||
use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption};
|
||||
|
||||
use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner};
|
||||
|
||||
@ -422,7 +415,8 @@ mod test {
|
||||
/// We can't easily cover everything here but may as well test the basics.
|
||||
#[tokio::test]
|
||||
async fn test_fullfat_store_basic() {
|
||||
let vcs = VolatileCodeStore::default();
|
||||
let clock = Clock::new_test();
|
||||
let vcs = VolatileCodeStore::new(clock.clone());
|
||||
|
||||
vcs.add_redeemable(
|
||||
VALID_CODE.clone(),
|
||||
@ -435,7 +429,7 @@ mod test {
|
||||
user_id: USER_UUID,
|
||||
user_login_session_id: USER_LOGIN_SESSION_ID,
|
||||
},
|
||||
u64::MAX,
|
||||
i64::MAX,
|
||||
);
|
||||
|
||||
assert_matches!(
|
||||
@ -458,6 +452,7 @@ mod test {
|
||||
// given a moment
|
||||
1,
|
||||
);
|
||||
clock.set_time(2);
|
||||
|
||||
vcs.add_redeemable(
|
||||
AuthCode([0; 24]),
|
||||
@ -470,11 +465,11 @@ mod test {
|
||||
user_id: USER_UUID,
|
||||
user_login_session_id: USER_LOGIN_SESSION_ID,
|
||||
},
|
||||
u64::MAX,
|
||||
i64::MAX,
|
||||
);
|
||||
|
||||
// Give a short time for the expiry to take place.
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
tokio::time::sleep(Duration::from_millis(2)).await;
|
||||
|
||||
assert_matches!(
|
||||
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
//! `/oidc/token`
|
||||
|
||||
use std::{
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
use axum::{
|
||||
extract::rejection::FormRejection,
|
||||
@ -15,13 +11,13 @@ use axum::{
|
||||
};
|
||||
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
||||
use blake2::Blake2s256;
|
||||
use chrono::{Duration, Utc};
|
||||
use chrono::Duration;
|
||||
use eyre::{bail, Context};
|
||||
use josekit::{
|
||||
jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader},
|
||||
jwt::JwtPayload,
|
||||
};
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use subtle::ConstantTimeEq;
|
||||
@ -30,6 +26,7 @@ use tracing::{debug, error};
|
||||
use crate::{
|
||||
config::{Configuration, SecretConfig},
|
||||
store::IdCoopStore,
|
||||
utils::{Clock, RandGen},
|
||||
};
|
||||
|
||||
use super::ext_codes::{
|
||||
@ -53,12 +50,15 @@ pub struct TokenFormParams {
|
||||
/// OpenID Connect clients call this to exchange an authorisation code they received for an access token.
|
||||
///
|
||||
/// TODO auth_header can be one alternative auth method
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn oidc_token(
|
||||
basic_auth: Option<TypedHeader<Authorization<Basic>>>,
|
||||
Extension(config): Extension<Arc<Configuration>>,
|
||||
Extension(secrets): Extension<Arc<SecretConfig>>,
|
||||
Extension(store): Extension<Arc<IdCoopStore>>,
|
||||
Extension(code_store): Extension<VolatileCodeStore>,
|
||||
Extension(mut randgen): Extension<RandGen>,
|
||||
Extension(clock): Extension<Clock>,
|
||||
form: Result<Form<TokenFormParams>, FormRejection>,
|
||||
) -> impl IntoResponse {
|
||||
let form = match form {
|
||||
@ -110,8 +110,9 @@ pub async fn oidc_token(
|
||||
Json(TokenError {
|
||||
code: TokenErrorCode::InvalidClient,
|
||||
description: "That `client_id` is not recognised here.".to_string(),
|
||||
})
|
||||
).into_response();
|
||||
}),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
if !bool::from(
|
||||
@ -181,10 +182,10 @@ pub async fn oidc_token(
|
||||
// Create an access token but don't actually issue it yet:
|
||||
// This lets us store the hash of the access token against the redemption of the auth code,
|
||||
// so double redemptions can invalidate the access token appropriately.
|
||||
let access_token = thread_rng().gen::<AccessToken>();
|
||||
let access_token = randgen.gen::<AccessToken>();
|
||||
let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token);
|
||||
let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into();
|
||||
let refresh_token = thread_rng().gen::<RefreshToken>();
|
||||
let refresh_token = randgen.gen::<RefreshToken>();
|
||||
let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token);
|
||||
let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into();
|
||||
|
||||
@ -256,7 +257,9 @@ pub async fn oidc_token(
|
||||
}
|
||||
|
||||
// 2. Check the code challenge
|
||||
let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else {
|
||||
let Some(computed_code_challenge) =
|
||||
compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier)
|
||||
else {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(TokenError {
|
||||
@ -287,28 +290,29 @@ pub async fn oidc_token(
|
||||
let user_login_session_id = binding.user_login_session_id;
|
||||
|
||||
// Issue access token for a new session
|
||||
let clock2 = clock.clone();
|
||||
match store
|
||||
.txn(move |mut txn| {
|
||||
Box::pin(async move {
|
||||
let Some(session_id) = txn
|
||||
.create_application_session(user_id, &application_id, user_login_session_id)
|
||||
.await
|
||||
.context("create_application_session")? else {
|
||||
return Ok(Err(
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(TokenError {
|
||||
code: TokenErrorCode::InvalidGrant,
|
||||
description: "Auth code has expired or was not valid.".to_owned(),
|
||||
}),
|
||||
).into_response()
|
||||
));
|
||||
.context("create_application_session")?
|
||||
else {
|
||||
return Ok(Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(TokenError {
|
||||
code: TokenErrorCode::InvalidGrant,
|
||||
description: "Auth code has expired or was not valid.".to_owned(),
|
||||
}),
|
||||
)
|
||||
.into_response()));
|
||||
};
|
||||
txn.issue_access_token(
|
||||
&access_token_hash,
|
||||
session_id,
|
||||
// TODO(expiry) Support custom expiry, not 100 years
|
||||
Utc::now() + Duration::days(365 * 100),
|
||||
clock2.now_utc() + Duration::days(365 * 100),
|
||||
)
|
||||
.await
|
||||
.context("issue_access_token")?;
|
||||
@ -316,7 +320,7 @@ pub async fn oidc_token(
|
||||
&refresh_token_hash,
|
||||
session_id,
|
||||
// TODO(expiry) Support custom expiry, not 100 years
|
||||
Utc::now() + Duration::days(365 * 100),
|
||||
clock2.now_utc() + Duration::days(365 * 100),
|
||||
)
|
||||
.await
|
||||
.context("issue_refresh_token")?;
|
||||
@ -344,10 +348,7 @@ pub async fn oidc_token(
|
||||
}
|
||||
}
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("Before unix epoch?")
|
||||
.as_secs();
|
||||
let now = clock.now_timestamp();
|
||||
// TODO(expiry) Support custom expiry times (not just 100 years)
|
||||
let exp = now + 100 * 365 * 86400;
|
||||
let sub = binding.user_id.hyphenated().to_string();
|
||||
@ -385,7 +386,9 @@ pub async fn oidc_token(
|
||||
}
|
||||
|
||||
fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result<String> {
|
||||
let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else {
|
||||
let Ok(serde_json::Value::Object(map)) =
|
||||
serde_json::to_value(id_token).context("failed to serialise ID Token content")
|
||||
else {
|
||||
bail!("ID Token not a map");
|
||||
};
|
||||
|
||||
@ -447,9 +450,9 @@ pub struct IdToken {
|
||||
/// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew.
|
||||
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||
/// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular.
|
||||
pub exp: u64,
|
||||
pub exp: i64,
|
||||
/// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||
pub iat: u64,
|
||||
pub iat: i64,
|
||||
/// Time when the End-User authentication occurred.
|
||||
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||
/// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user