Tests for the full flow up to and including the token endpoint
Signed-off-by: Olivier 'reivilibre <olivier@librepush.net>
This commit is contained in:
parent
f2b0a64fb0
commit
13e6cd5361
10
src/store.rs
10
src/store.rs
@ -110,7 +110,7 @@ pub struct UserInfo {
|
|||||||
|
|
||||||
/// A wrapper around a database transaction with some database methods on it.
|
/// A wrapper around a database transaction with some database methods on it.
|
||||||
pub struct IdCoopStoreTxn<'a, 'txn> {
|
pub struct IdCoopStoreTxn<'a, 'txn> {
|
||||||
txn: &'a mut Transaction<'txn, Postgres>,
|
pub(crate) txn: &'a mut Transaction<'txn, Postgres>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
||||||
@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
|||||||
.await
|
.await
|
||||||
.context("failed to lookup login session")?;
|
.context("failed to lookup login session")?;
|
||||||
|
|
||||||
let Some(row) = row_opt else { return Ok(None); };
|
let Some(row) = row_opt else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Some(LoginSession {
|
Ok(Some(LoginSession {
|
||||||
user_name: row.user_name,
|
user_name: row.user_name,
|
||||||
@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
|
|||||||
.await
|
.await
|
||||||
.context("failed to lookup application session")?;
|
.context("failed to lookup application session")?;
|
||||||
|
|
||||||
let Some(row) = row_opt else { return Ok(None); };
|
let Some(row) = row_opt else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Some(ApplicationSession {
|
Ok(Some(ApplicationSession {
|
||||||
application_session_id: row.session_id,
|
application_session_id: row.session_id,
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||||||
use axum::Router;
|
use axum::Router;
|
||||||
use confique::{Config, Partial};
|
use confique::{Config, Partial};
|
||||||
use josekit::jwk::alg::rsa::RsaKeyPair;
|
use josekit::jwk::alg::rsa::RsaKeyPair;
|
||||||
|
use metrics::atomics::AtomicU64;
|
||||||
use pgtemp::PgTempDB;
|
use pgtemp::PgTempDB;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use rand_xoshiro::Xoshiro256StarStar;
|
use rand_xoshiro::Xoshiro256StarStar;
|
||||||
@ -20,6 +21,7 @@ struct TestSystem {
|
|||||||
web: Router,
|
web: Router,
|
||||||
config: Arc<Configuration>,
|
config: Arc<Configuration>,
|
||||||
store: Arc<IdCoopStore>,
|
store: Arc<IdCoopStore>,
|
||||||
|
clock: Clock,
|
||||||
}
|
}
|
||||||
|
|
||||||
const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem");
|
const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem");
|
||||||
@ -85,13 +87,13 @@ async fn basic_system() -> TestSystem {
|
|||||||
|
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let store = Arc::new(store);
|
let store = Arc::new(store);
|
||||||
let clock = Clock::Fake();
|
let clock = Clock(Arc::new(AtomicU64::new(0)));
|
||||||
let randgen = RandGen(Xoshiro256StarStar::seed_from_u64(424242));
|
let randgen = RandGen(Xoshiro256StarStar::seed_from_u64(424242));
|
||||||
let router = make_router(
|
let router = make_router(
|
||||||
store.clone(),
|
store.clone(),
|
||||||
config.clone(),
|
config.clone(),
|
||||||
Arc::new(secrets),
|
Arc::new(secrets),
|
||||||
clock,
|
clock.clone(),
|
||||||
randgen,
|
randgen,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -102,6 +104,7 @@ async fn basic_system() -> TestSystem {
|
|||||||
web: router,
|
web: router,
|
||||||
config,
|
config,
|
||||||
store,
|
store,
|
||||||
|
clock,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ expression: "(headers, text)"
|
|||||||
---
|
---
|
||||||
- content-length: "55"
|
- content-length: "55"
|
||||||
content-type: text/plain; charset=utf-8
|
content-type: text/plain; charset=utf-8
|
||||||
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey"
|
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey"
|
||||||
set-cookie: __Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
|
set-cookie: __Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
|
||||||
x-frame-options: DENY
|
x-frame-options: DENY
|
||||||
- Logged in. Redirecting you back to what you were doing.
|
- Logged in. Redirecting you back to what you were doing.
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
---
|
||||||
|
source: src/tests/test_oidc_auth_flow.rs
|
||||||
|
expression: "(headers, text)"
|
||||||
|
---
|
||||||
|
- content-length: "288"
|
||||||
|
content-type: text/html; charset=utf-8
|
||||||
|
x-frame-options: DENY
|
||||||
|
- "hi <u>robert</u>, consent to <u>AClient</u>? <form method='POST'><input type='hidden' name='xsrf' value='0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>"
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
---
|
||||||
|
source: src/tests/test_oidc_auth_flow.rs
|
||||||
|
expression: "(headers, text)"
|
||||||
|
---
|
||||||
|
- content-length: "46"
|
||||||
|
content-type: text/plain; charset=utf-8
|
||||||
|
location: "http://aclient.example.com/redirect?code=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO&state=wombat&iss=http%3A%2F%2Fissuer.example.com"
|
||||||
|
x-frame-options: DENY
|
||||||
|
- Authorisation succeeded; redirecting you back.
|
||||||
@ -0,0 +1,14 @@
|
|||||||
|
---
|
||||||
|
source: src/tests/test_oidc_auth_flow.rs
|
||||||
|
expression: "(headers, json)"
|
||||||
|
---
|
||||||
|
- access-control-allow-origin: "*"
|
||||||
|
access-control-expose-headers: "*"
|
||||||
|
content-length: "803"
|
||||||
|
content-type: application/json
|
||||||
|
- access_token: HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE
|
||||||
|
expires_in: 31536000
|
||||||
|
id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg
|
||||||
|
refresh_token: bn0vvGrpCiQDMaH7MRw7FgHuPqoEYel0IFGYv_Z6E7o
|
||||||
|
scope: openid
|
||||||
|
token_type: Bearer
|
||||||
@ -7,6 +7,7 @@ use axum_test_helper::{TestClient, TestResponse};
|
|||||||
use insta::assert_yaml_snapshot;
|
use insta::assert_yaml_snapshot;
|
||||||
|
|
||||||
use maplit::btreemap;
|
use maplit::btreemap;
|
||||||
|
use sqlx::types::Uuid;
|
||||||
|
|
||||||
use crate::{passwords::create_password_hash, store::CreateUser, tests::basic_system};
|
use crate::{passwords::create_password_hash, store::CreateUser, tests::basic_system};
|
||||||
|
|
||||||
@ -21,10 +22,15 @@ async fn dump_resp_text(
|
|||||||
.headers()
|
.headers()
|
||||||
.clone()
|
.clone()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(k, v)| (k.unwrap().to_string(), v.to_str().unwrap().to_owned()))
|
// skip headers that are duplicate keys.
|
||||||
|
// We don't care about theme for our purposes.
|
||||||
|
.filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned())))
|
||||||
.collect();
|
.collect();
|
||||||
// Remove date because it's not stable across tests!
|
// Remove date because it's not stable across tests!
|
||||||
headers.remove("date");
|
headers.remove("date");
|
||||||
|
// Remove vary because it has multiple values and we don't want to
|
||||||
|
// introduce instability into our tests by only allowing one through.
|
||||||
|
headers.remove("vary");
|
||||||
let text = resp.text().await;
|
let text = resp.text().await;
|
||||||
eprintln!("=== Response for {req_name} ===");
|
eprintln!("=== Response for {req_name} ===");
|
||||||
eprintln!("Status: {status:?}");
|
eprintln!("Status: {status:?}");
|
||||||
@ -36,21 +42,20 @@ async fn dump_resp_text(
|
|||||||
|
|
||||||
/// Tests the full flow...
|
/// Tests the full flow...
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_todo() {
|
async fn test_full_flow() {
|
||||||
let sys = basic_system().await;
|
let sys = basic_system().await;
|
||||||
|
|
||||||
|
let uuid = Uuid::nil();
|
||||||
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
|
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
|
||||||
let _: () = sys
|
let _: () = sys
|
||||||
.store
|
.store
|
||||||
.txn(|mut txn| {
|
.txn(|mut txn| {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
txn.create_user(CreateUser {
|
sqlx::query(
|
||||||
user_login_name: "robert".to_owned(),
|
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
|
||||||
password_hash: Some(pwhash),
|
).bind("robert").bind(uuid).bind(pwhash).bind(false)
|
||||||
locked: false,
|
.fetch_one(&mut **txn.txn)
|
||||||
})
|
.await.unwrap();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@ -59,22 +64,28 @@ async fn test_todo() {
|
|||||||
|
|
||||||
let client = TestClient::new(sys.web);
|
let client = TestClient::new(sys.web);
|
||||||
|
|
||||||
|
///// These requests are on behalf of the user's browser /////
|
||||||
|
|
||||||
|
const CODE_VERIFIER: &str = "verifying";
|
||||||
|
// base64(sha256("verifying"))
|
||||||
|
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
|
||||||
|
|
||||||
// 1. /auth request
|
// 1. /auth request
|
||||||
const LOGIN_URL: &str = "/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3Dchallenging%26code_challenge_method%3DS256%26nonce%3Dnoncey";
|
let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey");
|
||||||
let resp = client.get("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=challenging&code_challenge_method=S256&nonce=noncey").send().await;
|
let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await;
|
||||||
let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await;
|
let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await;
|
||||||
assert_eq!(status, 302);
|
assert_eq!(status, 302);
|
||||||
assert_eq!(headers.get("location").unwrap(), LOGIN_URL);
|
assert_eq!(headers.get("location").unwrap(), &login_url);
|
||||||
|
|
||||||
// 2. /login request
|
// 2. /login request
|
||||||
let resp = client.get(LOGIN_URL).send().await;
|
let resp = client.get(&login_url).send().await;
|
||||||
let (status, headers, text) = dump_resp_text("2. /login request", resp).await;
|
let (status, headers, text) = dump_resp_text("2. /login request", resp).await;
|
||||||
assert_eq!(status, 200);
|
assert_eq!(status, 200);
|
||||||
assert_yaml_snapshot!("2/login", (headers, text));
|
assert_yaml_snapshot!("2/login", (headers, text));
|
||||||
|
|
||||||
// 3. /login request with credentials
|
// 3. /login request with credentials
|
||||||
let resp = client
|
let resp = client
|
||||||
.post(LOGIN_URL)
|
.post(&login_url)
|
||||||
.form(&btreemap! {
|
.form(&btreemap! {
|
||||||
"username" => "robert",
|
"username" => "robert",
|
||||||
"password" => "secret",
|
"password" => "secret",
|
||||||
@ -87,5 +98,58 @@ async fn test_todo() {
|
|||||||
.await;
|
.await;
|
||||||
let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await;
|
let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await;
|
||||||
assert_eq!(status, 302);
|
assert_eq!(status, 302);
|
||||||
|
let auth_loc = headers.get("location").unwrap().to_owned();
|
||||||
assert_yaml_snapshot!("3/login", (headers, text));
|
assert_yaml_snapshot!("3/login", (headers, text));
|
||||||
|
|
||||||
|
// 4. /auth request
|
||||||
|
let resp = client
|
||||||
|
.get(&auth_loc)
|
||||||
|
.header(
|
||||||
|
"Cookie",
|
||||||
|
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
|
||||||
|
)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await;
|
||||||
|
assert_eq!(status, 200);
|
||||||
|
assert_yaml_snapshot!("4/auth", (headers, text));
|
||||||
|
|
||||||
|
sys.clock.set_time(30);
|
||||||
|
|
||||||
|
// 5. /auth request with confirmation
|
||||||
|
let resp = client
|
||||||
|
.post(&auth_loc)
|
||||||
|
.header(
|
||||||
|
"Cookie",
|
||||||
|
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
|
||||||
|
)
|
||||||
|
.form(&btreemap! {
|
||||||
|
"action" => "accept",
|
||||||
|
"xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y",
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await;
|
||||||
|
assert_eq!(status, 302);
|
||||||
|
// Note this snapshot includes a Location: header back to the client application
|
||||||
|
assert_yaml_snapshot!("5/auth", (headers, text));
|
||||||
|
|
||||||
|
///// At this point, we make requests on behalf of the client /////
|
||||||
|
|
||||||
|
// 6. /token request
|
||||||
|
let resp = client
|
||||||
|
.post("/oidc/token")
|
||||||
|
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
|
||||||
|
.form(&btreemap! {
|
||||||
|
"code" => "HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO",
|
||||||
|
"code_verifier" => CODE_VERIFIER,
|
||||||
|
"grant_type" => "authorization_code",
|
||||||
|
"redirect_uri" => "http://aclient.example.com/redirect",
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
let (status, headers, text) = dump_resp_text("6. POST /token", resp).await;
|
||||||
|
assert_eq!(status, 200);
|
||||||
|
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
|
||||||
|
assert_yaml_snapshot!("6/token", (headers, json));
|
||||||
}
|
}
|
||||||
|
|||||||
94
src/utils.rs
94
src/utils.rs
@ -1,7 +1,16 @@
|
|||||||
//! Miscellaneous utilities
|
//! Miscellaneous utilities
|
||||||
|
|
||||||
|
#[cfg(not(test))]
|
||||||
|
pub use self::real_utils::{Clock, RandGen};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub use self::test_utils::{Clock, RandGen};
|
||||||
|
|
||||||
#[cfg(not(test))]
|
#[cfg(not(test))]
|
||||||
mod real_utils {
|
mod real_utils {
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
use rand::{thread_rng, RngCore};
|
use rand::{thread_rng, RngCore};
|
||||||
|
|
||||||
/// A source of random numbers that can be faked for tests.
|
/// A source of random numbers that can be faked for tests.
|
||||||
@ -25,12 +34,47 @@ mod real_utils {
|
|||||||
thread_rng().try_fill_bytes(dest)
|
thread_rng().try_fill_bytes(dest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A source of time that can be faked for tests.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Clock;
|
||||||
|
|
||||||
|
impl Clock {
|
||||||
|
/// Returns the current time as a `DateTime<Utc>`.
|
||||||
|
pub fn now_utc(&self) -> DateTime<Utc> {
|
||||||
|
Utc::now()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the current time as a i64 of seconds since the Unix epoch.
|
||||||
|
pub fn now_timestamp(&self) -> i64 {
|
||||||
|
Utc::now().timestamp()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sleep until the given timestamp.
|
||||||
|
pub async fn sleep_until(&self, until_ts: i64) {
|
||||||
|
if until_ts < 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64);
|
||||||
|
let now = SystemTime::now();
|
||||||
|
let sleep_for = sleep_until
|
||||||
|
.duration_since(now)
|
||||||
|
.unwrap_or(Duration::from_secs(0));
|
||||||
|
tokio::time::sleep(sleep_for).await
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_utils {
|
mod test_utils {
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::sync::atomic::AtomicU64;
|
||||||
|
use std::time::Duration;
|
||||||
|
use std::{
|
||||||
|
ops::{Deref, DerefMut},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
use chrono::{DateTime, TimeZone, Utc};
|
||||||
use rand_xoshiro::Xoshiro256StarStar;
|
use rand_xoshiro::Xoshiro256StarStar;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -48,19 +92,37 @@ mod test_utils {
|
|||||||
&mut self.0
|
&mut self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
#[derive(Clone)]
|
||||||
#[cfg(not(test))]
|
pub struct Clock(pub Arc<AtomicU64>);
|
||||||
pub use self::real_utils::RandGen;
|
|
||||||
|
impl Clock {
|
||||||
#[cfg(test)]
|
pub fn new_test() -> Self {
|
||||||
pub use self::test_utils::RandGen;
|
Clock(Arc::new(AtomicU64::new(0)))
|
||||||
|
}
|
||||||
/// A source of time that can be faked for tests.
|
pub fn set_time(&self, new: u64) {
|
||||||
#[derive(Clone)]
|
self.0.store(new, std::sync::atomic::Ordering::Relaxed);
|
||||||
pub enum Clock {
|
}
|
||||||
/// Use real time
|
|
||||||
Real,
|
pub fn now_utc(&self) -> DateTime<Utc> {
|
||||||
/// Fake time for use in tests
|
Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0)
|
||||||
Fake(),
|
.earliest()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn now_timestamp(&self) -> i64 {
|
||||||
|
self.0.load(std::sync::atomic::Ordering::Relaxed) as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn sleep_until(&self, until_ts: i64) {
|
||||||
|
if until_ts < 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Wait for time to advance past the requested timestamp
|
||||||
|
// TODO write a better test sleep implementation
|
||||||
|
while self.now_timestamp() < until_ts {
|
||||||
|
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -125,7 +125,7 @@ pub(crate) async fn make_router(
|
|||||||
.layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1))))
|
.layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1))))
|
||||||
.layer(client_ip_source.into_extension())
|
.layer(client_ip_source.into_extension())
|
||||||
.layer(Extension(Arc::new(ratelimiters)))
|
.layer(Extension(Arc::new(ratelimiters)))
|
||||||
.layer(Extension(VolatileCodeStore::default()))
|
.layer(Extension(VolatileCodeStore::new(clock.clone())))
|
||||||
.layer(Extension(clock))
|
.layer(Extension(clock))
|
||||||
.layer(Extension(randgen));
|
.layer(Extension(randgen));
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ pub async fn serve(
|
|||||||
use eyre::Context;
|
use eyre::Context;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
let router = make_router(store, config, secrets, Clock::Real, RandGen).await?;
|
let router = make_router(store, config, secrets, Clock, RandGen).await?;
|
||||||
|
|
||||||
info!("Listening on {bind:?}");
|
info!("Listening on {bind:?}");
|
||||||
axum::Server::try_bind(&bind)
|
axum::Server::try_bind(&bind)
|
||||||
|
|||||||
@ -21,7 +21,7 @@ use chrono::{DateTime, Duration, TimeZone, Utc};
|
|||||||
use eyre::eyre;
|
use eyre::eyre;
|
||||||
use eyre::{bail, Context, ContextCompat};
|
use eyre::{bail, Context, ContextCompat};
|
||||||
use governor::Jitter;
|
use governor::Jitter;
|
||||||
use rand::{thread_rng, Rng};
|
use rand::Rng;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use sqlx::types::Uuid;
|
use sqlx::types::Uuid;
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
|||||||
@ -9,7 +9,6 @@ use axum::{
|
|||||||
Extension, Form,
|
Extension, Form,
|
||||||
};
|
};
|
||||||
|
|
||||||
use chrono::Utc;
|
|
||||||
use eyre::{Context, ContextCompat};
|
use eyre::{Context, ContextCompat};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -17,6 +16,7 @@ use tracing::{error, warn};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{Configuration, OidcClientConfiguration},
|
config::{Configuration, OidcClientConfiguration},
|
||||||
|
utils::{Clock, RandGen},
|
||||||
web::{
|
web::{
|
||||||
login::LoginSession,
|
login::LoginSession,
|
||||||
make_login_redirect,
|
make_login_redirect,
|
||||||
@ -69,6 +69,8 @@ pub async fn oidc_authorisation(
|
|||||||
login_session: Option<LoginSession>,
|
login_session: Option<LoginSession>,
|
||||||
Extension(config): Extension<Arc<Configuration>>,
|
Extension(config): Extension<Arc<Configuration>>,
|
||||||
Extension(code_store): Extension<VolatileCodeStore>,
|
Extension(code_store): Extension<VolatileCodeStore>,
|
||||||
|
Extension(clock): Extension<Clock>,
|
||||||
|
Extension(mut randgen): Extension<RandGen>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let Query(query) = match query {
|
let Query(query) = match query {
|
||||||
@ -109,7 +111,7 @@ pub async fn oidc_authorisation(
|
|||||||
|
|
||||||
// If the application requires consent, then we should ask for that.
|
// If the application requires consent, then we should ask for that.
|
||||||
if !client_config.skip_consent {
|
if !client_config.skip_consent {
|
||||||
return show_consent_page(login_session, client_config, &config).await;
|
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// No consent needed: process the authorisation.
|
// No consent needed: process the authorisation.
|
||||||
@ -120,6 +122,8 @@ pub async fn oidc_authorisation(
|
|||||||
client_id,
|
client_id,
|
||||||
client_config,
|
client_config,
|
||||||
&config,
|
&config,
|
||||||
|
&mut randgen,
|
||||||
|
&clock,
|
||||||
&code_store,
|
&code_store,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -133,11 +137,14 @@ pub struct PostConsentForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `POST /oidc/auth`
|
/// `POST /oidc/auth`
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn post_oidc_authorisation_consent(
|
pub async fn post_oidc_authorisation_consent(
|
||||||
Query(query): Query<AuthorisationQuery>,
|
Query(query): Query<AuthorisationQuery>,
|
||||||
login_session: Option<LoginSession>,
|
login_session: Option<LoginSession>,
|
||||||
Extension(config): Extension<Arc<Configuration>>,
|
Extension(config): Extension<Arc<Configuration>>,
|
||||||
Extension(code_store): Extension<VolatileCodeStore>,
|
Extension(code_store): Extension<VolatileCodeStore>,
|
||||||
|
Extension(clock): Extension<Clock>,
|
||||||
|
Extension(mut randgen): Extension<RandGen>,
|
||||||
OriginalUri(uri): OriginalUri,
|
OriginalUri(uri): OriginalUri,
|
||||||
Form(form): Form<PostConsentForm>,
|
Form(form): Form<PostConsentForm>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if login_session
|
if login_session
|
||||||
.validate_xsrf_token(&form.xsrf, Utc::now())
|
.validate_xsrf_token(&form.xsrf, clock.now_utc())
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
// XSRF token is not valid, so show the consent form again...
|
// XSRF token is not valid, so show the consent form again...
|
||||||
return show_consent_page(login_session, client_config, &config).await;
|
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
match form.action.as_str() {
|
match form.action.as_str() {
|
||||||
@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent(
|
|||||||
client_id,
|
client_id,
|
||||||
client_config,
|
client_config,
|
||||||
&config,
|
&config,
|
||||||
|
&mut randgen,
|
||||||
|
&clock,
|
||||||
&code_store,
|
&code_store,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>(
|
|||||||
async fn show_consent_page(
|
async fn show_consent_page(
|
||||||
login_session: LoginSession,
|
login_session: LoginSession,
|
||||||
client_config: &OidcClientConfiguration,
|
client_config: &OidcClientConfiguration,
|
||||||
|
Extension(clock): Extension<Clock>,
|
||||||
_config: &Configuration,
|
_config: &Configuration,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let xsrf_token = login_session
|
let xsrf_token = login_session
|
||||||
.generate_xsrf_token(Utc::now())
|
.generate_xsrf_token(clock.now_utc())
|
||||||
.expect("must be able to create a XSRF token");
|
.expect("must be able to create a XSRF token");
|
||||||
Html(format!(
|
Html(format!(
|
||||||
"hi <u>{}</u>, consent to <u>{}</u>? <form method='POST'><input type='hidden' name='xsrf' value='{}'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>",
|
"hi <u>{}</u>, consent to <u>{}</u>? <form method='POST'><input type='hidden' name='xsrf' value='{}'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>",
|
||||||
@ -253,12 +263,15 @@ async fn show_consent_page(
|
|||||||
/// Preconditions:
|
/// Preconditions:
|
||||||
/// - any required consent from the user has now been obtained
|
/// - any required consent from the user has now been obtained
|
||||||
/// - query.request_uri has been validated as a safe redirect URI
|
/// - query.request_uri has been validated as a safe redirect URI
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn process_authorisation(
|
async fn process_authorisation(
|
||||||
query: AuthorisationQuery,
|
query: AuthorisationQuery,
|
||||||
login_session: LoginSession,
|
login_session: LoginSession,
|
||||||
client_id: String,
|
client_id: String,
|
||||||
_client_config: &OidcClientConfiguration,
|
_client_config: &OidcClientConfiguration,
|
||||||
config: &Configuration,
|
config: &Configuration,
|
||||||
|
randgen: &mut RandGen,
|
||||||
|
clock: &Clock,
|
||||||
code_store: &VolatileCodeStore,
|
code_store: &VolatileCodeStore,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -271,7 +284,7 @@ async fn process_authorisation(
|
|||||||
|
|
||||||
// Generate a 192-bit random code, which fits into exactly 32 base64 characters.
|
// Generate a 192-bit random code, which fits into exactly 32 base64 characters.
|
||||||
// This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random.
|
// This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random.
|
||||||
let code = AuthCode::generate_new_random();
|
let code = AuthCode::generate_new_random(randgen);
|
||||||
let code_base64url = code.to_string();
|
let code_base64url = code.to_string();
|
||||||
|
|
||||||
// Write down the code and other details in-memory with 10 minute expiry...
|
// Write down the code and other details in-memory with 10 minute expiry...
|
||||||
@ -289,7 +302,7 @@ async fn process_authorisation(
|
|||||||
user_id: login_session.user_id,
|
user_id: login_session.user_id,
|
||||||
user_login_session_id: login_session.login_session_id,
|
user_login_session_id: login_session.login_session_id,
|
||||||
},
|
},
|
||||||
0,
|
clock.now_timestamp() + 600,
|
||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
|
|||||||
@ -6,7 +6,6 @@ use std::{
|
|||||||
fmt::Display,
|
fmt::Display,
|
||||||
str::FromStr,
|
str::FromStr,
|
||||||
sync::{Arc, Mutex},
|
sync::{Arc, Mutex},
|
||||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
||||||
@ -14,6 +13,8 @@ use rand::Rng;
|
|||||||
use sqlx::types::Uuid;
|
use sqlx::types::Uuid;
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
|
|
||||||
|
use crate::utils::{Clock, RandGen};
|
||||||
|
|
||||||
/// Display shows the auth code as base64 (URL-safe non-padded).
|
/// Display shows the auth code as base64 (URL-safe non-padded).
|
||||||
/// FromStr/parse parses the same format.
|
/// FromStr/parse parses the same format.
|
||||||
#[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
|
#[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
|
||||||
@ -61,8 +62,8 @@ impl FromStr for AuthCode {
|
|||||||
|
|
||||||
impl AuthCode {
|
impl AuthCode {
|
||||||
/// Generate a new authorisation code using the thread's RNG
|
/// Generate a new authorisation code using the thread's RNG
|
||||||
pub fn generate_new_random() -> Self {
|
pub fn generate_new_random(randgen: &mut RandGen) -> Self {
|
||||||
Self(rand::thread_rng().gen::<[u8; 24]>())
|
Self(randgen.gen::<[u8; 24]>())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,7 +113,7 @@ struct VolatileCodeStoreInner {
|
|||||||
pub conflictable_codes: HashMap<AuthCode, RedeemedAuthCode>,
|
pub conflictable_codes: HashMap<AuthCode, RedeemedAuthCode>,
|
||||||
|
|
||||||
/// Time when codes will expire
|
/// Time when codes will expire
|
||||||
pub expire_codes_at: BTreeSet<(u64, AuthCode)>,
|
pub expire_codes_at: BTreeSet<(i64, AuthCode)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VolatileCodeStoreInner {
|
impl VolatileCodeStoreInner {
|
||||||
@ -148,7 +149,7 @@ impl VolatileCodeStoreInner {
|
|||||||
&mut self,
|
&mut self,
|
||||||
auth_code: AuthCode,
|
auth_code: AuthCode,
|
||||||
auth_code_binding: AuthCodeBinding,
|
auth_code_binding: AuthCodeBinding,
|
||||||
expires_at: u64,
|
expires_at: i64,
|
||||||
) {
|
) {
|
||||||
self.redeemable_codes
|
self.redeemable_codes
|
||||||
.insert(auth_code.clone(), auth_code_binding);
|
.insert(auth_code.clone(), auth_code_binding);
|
||||||
@ -156,7 +157,7 @@ impl VolatileCodeStoreInner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Removes all expired auth codes and returns the time of the earliest next expiry, if present.
|
/// Removes all expired auth codes and returns the time of the earliest next expiry, if present.
|
||||||
pub(self) fn handle_expiry(&mut self, now: u64) -> Option<u64> {
|
pub(self) fn handle_expiry(&mut self, now: i64) -> Option<i64> {
|
||||||
loop {
|
loop {
|
||||||
let (ts, _auth_code) = self.expire_codes_at.first()?;
|
let (ts, _auth_code) = self.expire_codes_at.first()?;
|
||||||
|
|
||||||
@ -182,15 +183,16 @@ pub struct VolatileCodeStore {
|
|||||||
inner: Arc<Mutex<VolatileCodeStoreInner>>,
|
inner: Arc<Mutex<VolatileCodeStoreInner>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for VolatileCodeStore {
|
impl VolatileCodeStore {
|
||||||
fn default() -> Self {
|
/// Create a new instance.
|
||||||
|
pub fn new(clock: Clock) -> Self {
|
||||||
let poke = Arc::new(Notify::new());
|
let poke = Arc::new(Notify::new());
|
||||||
let inner: Arc<Mutex<VolatileCodeStoreInner>> = Default::default();
|
let inner: Arc<Mutex<VolatileCodeStoreInner>> = Default::default();
|
||||||
|
|
||||||
{
|
{
|
||||||
let poke = poke.clone();
|
let poke = poke.clone();
|
||||||
let inner = inner.clone();
|
let inner = inner.clone();
|
||||||
tokio::spawn(Self::expirer(inner, poke));
|
tokio::spawn(Self::expirer(inner, poke, clock));
|
||||||
}
|
}
|
||||||
|
|
||||||
VolatileCodeStore { inner, poke }
|
VolatileCodeStore { inner, poke }
|
||||||
@ -198,19 +200,15 @@ impl Default for VolatileCodeStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl VolatileCodeStore {
|
impl VolatileCodeStore {
|
||||||
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>) {
|
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>, clock: Clock) {
|
||||||
let mut next_expiry: Option<u64> = None;
|
let mut next_expiry: Option<i64> = None;
|
||||||
loop {
|
loop {
|
||||||
match next_expiry {
|
match next_expiry {
|
||||||
Some(next_expiry) => {
|
Some(next_expiry) => {
|
||||||
let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry);
|
let sleep_future = clock.sleep_until(next_expiry);
|
||||||
let now = SystemTime::now();
|
|
||||||
let sleep_for = sleep_until
|
|
||||||
.duration_since(now)
|
|
||||||
.unwrap_or(Duration::from_secs(60));
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = poke.notified() => {},
|
_ = poke.notified() => {},
|
||||||
_ = tokio::time::sleep(sleep_for) => {},
|
_ = sleep_future => {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
@ -218,14 +216,9 @@ impl VolatileCodeStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let now = SystemTime::now();
|
|
||||||
next_expiry = {
|
next_expiry = {
|
||||||
let mut inner = inner.lock().unwrap();
|
let mut inner = inner.lock().unwrap();
|
||||||
inner.handle_expiry(
|
inner.handle_expiry(clock.now_timestamp())
|
||||||
now.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("system clock before unix epoch")
|
|
||||||
.as_secs(),
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -244,7 +237,7 @@ impl VolatileCodeStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new redeemable authorisation code.
|
/// Add a new redeemable authorisation code.
|
||||||
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) {
|
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) {
|
||||||
let mut inner = self.inner.lock().unwrap();
|
let mut inner = self.inner.lock().unwrap();
|
||||||
inner.add_redeemable(auth_code, binding, expires_at);
|
inner.add_redeemable(auth_code, binding, expires_at);
|
||||||
drop(inner);
|
drop(inner);
|
||||||
@ -284,7 +277,7 @@ mod test {
|
|||||||
use rstest::{fixture, rstest};
|
use rstest::{fixture, rstest};
|
||||||
use sqlx::types::Uuid;
|
use sqlx::types::Uuid;
|
||||||
|
|
||||||
use crate::web::oauth_openid::ext_codes::CodeRedemption;
|
use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption};
|
||||||
|
|
||||||
use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner};
|
use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner};
|
||||||
|
|
||||||
@ -422,7 +415,8 @@ mod test {
|
|||||||
/// We can't easily cover everything here but may as well test the basics.
|
/// We can't easily cover everything here but may as well test the basics.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_fullfat_store_basic() {
|
async fn test_fullfat_store_basic() {
|
||||||
let vcs = VolatileCodeStore::default();
|
let clock = Clock::new_test();
|
||||||
|
let vcs = VolatileCodeStore::new(clock.clone());
|
||||||
|
|
||||||
vcs.add_redeemable(
|
vcs.add_redeemable(
|
||||||
VALID_CODE.clone(),
|
VALID_CODE.clone(),
|
||||||
@ -435,7 +429,7 @@ mod test {
|
|||||||
user_id: USER_UUID,
|
user_id: USER_UUID,
|
||||||
user_login_session_id: USER_LOGIN_SESSION_ID,
|
user_login_session_id: USER_LOGIN_SESSION_ID,
|
||||||
},
|
},
|
||||||
u64::MAX,
|
i64::MAX,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
@ -458,6 +452,7 @@ mod test {
|
|||||||
// given a moment
|
// given a moment
|
||||||
1,
|
1,
|
||||||
);
|
);
|
||||||
|
clock.set_time(2);
|
||||||
|
|
||||||
vcs.add_redeemable(
|
vcs.add_redeemable(
|
||||||
AuthCode([0; 24]),
|
AuthCode([0; 24]),
|
||||||
@ -470,11 +465,11 @@ mod test {
|
|||||||
user_id: USER_UUID,
|
user_id: USER_UUID,
|
||||||
user_login_session_id: USER_LOGIN_SESSION_ID,
|
user_login_session_id: USER_LOGIN_SESSION_ID,
|
||||||
},
|
},
|
||||||
u64::MAX,
|
i64::MAX,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Give a short time for the expiry to take place.
|
// Give a short time for the expiry to take place.
|
||||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
tokio::time::sleep(Duration::from_millis(2)).await;
|
||||||
|
|
||||||
assert_matches!(
|
assert_matches!(
|
||||||
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),
|
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),
|
||||||
|
|||||||
@ -1,10 +1,6 @@
|
|||||||
//! `/oidc/token`
|
//! `/oidc/token`
|
||||||
|
|
||||||
use std::{
|
use std::{str::FromStr, sync::Arc};
|
||||||
str::FromStr,
|
|
||||||
sync::Arc,
|
|
||||||
time::{SystemTime, UNIX_EPOCH},
|
|
||||||
};
|
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::rejection::FormRejection,
|
extract::rejection::FormRejection,
|
||||||
@ -15,13 +11,13 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
|
||||||
use blake2::Blake2s256;
|
use blake2::Blake2s256;
|
||||||
use chrono::{Duration, Utc};
|
use chrono::Duration;
|
||||||
use eyre::{bail, Context};
|
use eyre::{bail, Context};
|
||||||
use josekit::{
|
use josekit::{
|
||||||
jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader},
|
jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader},
|
||||||
jwt::JwtPayload,
|
jwt::JwtPayload,
|
||||||
};
|
};
|
||||||
use rand::{thread_rng, Rng};
|
use rand::Rng;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use subtle::ConstantTimeEq;
|
use subtle::ConstantTimeEq;
|
||||||
@ -30,6 +26,7 @@ use tracing::{debug, error};
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{Configuration, SecretConfig},
|
config::{Configuration, SecretConfig},
|
||||||
store::IdCoopStore,
|
store::IdCoopStore,
|
||||||
|
utils::{Clock, RandGen},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::ext_codes::{
|
use super::ext_codes::{
|
||||||
@ -53,12 +50,15 @@ pub struct TokenFormParams {
|
|||||||
/// OpenID Connect clients call this to exchange an authorisation code they received for an access token.
|
/// OpenID Connect clients call this to exchange an authorisation code they received for an access token.
|
||||||
///
|
///
|
||||||
/// TODO auth_header can be one alternative auth method
|
/// TODO auth_header can be one alternative auth method
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn oidc_token(
|
pub async fn oidc_token(
|
||||||
basic_auth: Option<TypedHeader<Authorization<Basic>>>,
|
basic_auth: Option<TypedHeader<Authorization<Basic>>>,
|
||||||
Extension(config): Extension<Arc<Configuration>>,
|
Extension(config): Extension<Arc<Configuration>>,
|
||||||
Extension(secrets): Extension<Arc<SecretConfig>>,
|
Extension(secrets): Extension<Arc<SecretConfig>>,
|
||||||
Extension(store): Extension<Arc<IdCoopStore>>,
|
Extension(store): Extension<Arc<IdCoopStore>>,
|
||||||
Extension(code_store): Extension<VolatileCodeStore>,
|
Extension(code_store): Extension<VolatileCodeStore>,
|
||||||
|
Extension(mut randgen): Extension<RandGen>,
|
||||||
|
Extension(clock): Extension<Clock>,
|
||||||
form: Result<Form<TokenFormParams>, FormRejection>,
|
form: Result<Form<TokenFormParams>, FormRejection>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
let form = match form {
|
let form = match form {
|
||||||
@ -110,8 +110,9 @@ pub async fn oidc_token(
|
|||||||
Json(TokenError {
|
Json(TokenError {
|
||||||
code: TokenErrorCode::InvalidClient,
|
code: TokenErrorCode::InvalidClient,
|
||||||
description: "That `client_id` is not recognised here.".to_string(),
|
description: "That `client_id` is not recognised here.".to_string(),
|
||||||
})
|
}),
|
||||||
).into_response();
|
)
|
||||||
|
.into_response();
|
||||||
};
|
};
|
||||||
|
|
||||||
if !bool::from(
|
if !bool::from(
|
||||||
@ -181,10 +182,10 @@ pub async fn oidc_token(
|
|||||||
// Create an access token but don't actually issue it yet:
|
// Create an access token but don't actually issue it yet:
|
||||||
// This lets us store the hash of the access token against the redemption of the auth code,
|
// This lets us store the hash of the access token against the redemption of the auth code,
|
||||||
// so double redemptions can invalidate the access token appropriately.
|
// so double redemptions can invalidate the access token appropriately.
|
||||||
let access_token = thread_rng().gen::<AccessToken>();
|
let access_token = randgen.gen::<AccessToken>();
|
||||||
let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token);
|
let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token);
|
||||||
let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into();
|
let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into();
|
||||||
let refresh_token = thread_rng().gen::<RefreshToken>();
|
let refresh_token = randgen.gen::<RefreshToken>();
|
||||||
let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token);
|
let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token);
|
||||||
let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into();
|
let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into();
|
||||||
|
|
||||||
@ -256,7 +257,9 @@ pub async fn oidc_token(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Check the code challenge
|
// 2. Check the code challenge
|
||||||
let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else {
|
let Some(computed_code_challenge) =
|
||||||
|
compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier)
|
||||||
|
else {
|
||||||
return (
|
return (
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(TokenError {
|
Json(TokenError {
|
||||||
@ -287,28 +290,29 @@ pub async fn oidc_token(
|
|||||||
let user_login_session_id = binding.user_login_session_id;
|
let user_login_session_id = binding.user_login_session_id;
|
||||||
|
|
||||||
// Issue access token for a new session
|
// Issue access token for a new session
|
||||||
|
let clock2 = clock.clone();
|
||||||
match store
|
match store
|
||||||
.txn(move |mut txn| {
|
.txn(move |mut txn| {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let Some(session_id) = txn
|
let Some(session_id) = txn
|
||||||
.create_application_session(user_id, &application_id, user_login_session_id)
|
.create_application_session(user_id, &application_id, user_login_session_id)
|
||||||
.await
|
.await
|
||||||
.context("create_application_session")? else {
|
.context("create_application_session")?
|
||||||
return Ok(Err(
|
else {
|
||||||
(
|
return Ok(Err((
|
||||||
StatusCode::BAD_REQUEST,
|
StatusCode::BAD_REQUEST,
|
||||||
Json(TokenError {
|
Json(TokenError {
|
||||||
code: TokenErrorCode::InvalidGrant,
|
code: TokenErrorCode::InvalidGrant,
|
||||||
description: "Auth code has expired or was not valid.".to_owned(),
|
description: "Auth code has expired or was not valid.".to_owned(),
|
||||||
}),
|
}),
|
||||||
).into_response()
|
)
|
||||||
));
|
.into_response()));
|
||||||
};
|
};
|
||||||
txn.issue_access_token(
|
txn.issue_access_token(
|
||||||
&access_token_hash,
|
&access_token_hash,
|
||||||
session_id,
|
session_id,
|
||||||
// TODO(expiry) Support custom expiry, not 100 years
|
// TODO(expiry) Support custom expiry, not 100 years
|
||||||
Utc::now() + Duration::days(365 * 100),
|
clock2.now_utc() + Duration::days(365 * 100),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("issue_access_token")?;
|
.context("issue_access_token")?;
|
||||||
@ -316,7 +320,7 @@ pub async fn oidc_token(
|
|||||||
&refresh_token_hash,
|
&refresh_token_hash,
|
||||||
session_id,
|
session_id,
|
||||||
// TODO(expiry) Support custom expiry, not 100 years
|
// TODO(expiry) Support custom expiry, not 100 years
|
||||||
Utc::now() + Duration::days(365 * 100),
|
clock2.now_utc() + Duration::days(365 * 100),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("issue_refresh_token")?;
|
.context("issue_refresh_token")?;
|
||||||
@ -344,10 +348,7 @@ pub async fn oidc_token(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let now = SystemTime::now()
|
let now = clock.now_timestamp();
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.expect("Before unix epoch?")
|
|
||||||
.as_secs();
|
|
||||||
// TODO(expiry) Support custom expiry times (not just 100 years)
|
// TODO(expiry) Support custom expiry times (not just 100 years)
|
||||||
let exp = now + 100 * 365 * 86400;
|
let exp = now + 100 * 365 * 86400;
|
||||||
let sub = binding.user_id.hyphenated().to_string();
|
let sub = binding.user_id.hyphenated().to_string();
|
||||||
@ -385,7 +386,9 @@ pub async fn oidc_token(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result<String> {
|
fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result<String> {
|
||||||
let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else {
|
let Ok(serde_json::Value::Object(map)) =
|
||||||
|
serde_json::to_value(id_token).context("failed to serialise ID Token content")
|
||||||
|
else {
|
||||||
bail!("ID Token not a map");
|
bail!("ID Token not a map");
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -447,9 +450,9 @@ pub struct IdToken {
|
|||||||
/// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew.
|
/// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew.
|
||||||
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||||
/// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular.
|
/// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular.
|
||||||
pub exp: u64,
|
pub exp: i64,
|
||||||
/// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
/// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||||
pub iat: u64,
|
pub iat: i64,
|
||||||
/// Time when the End-User authentication occurred.
|
/// Time when the End-User authentication occurred.
|
||||||
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
|
||||||
/// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.
|
/// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user