Merge pull request 'Add tests to 66% coverage' (#1)

Reviewed-on: #1
This commit is contained in:
Olivier 'reivilibre' 2024-07-07 12:24:28 +00:00
commit a21b23add3
39 changed files with 2162 additions and 297 deletions

358
Cargo.lock generated
View File

@ -152,6 +152,12 @@ dependencies = [
"password-hash",
]
[[package]]
name = "assert_matches2"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15832d94c458da98cac0ffa6eca52cc19c2a3c6c951058500a5ae8f01f0fdf56"
[[package]]
name = "async-recursion"
version = "1.0.5"
@ -265,6 +271,24 @@ dependencies = [
"syn 2.0.38",
]
[[package]]
name = "axum-test-helper"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "298f62fa902c2515c169ab0bfb56c593229f33faa01131215d58e3d4898e3aa9"
dependencies = [
"axum",
"bytes",
"http",
"http-body",
"hyper",
"reqwest",
"serde",
"tokio",
"tower",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.69"
@ -301,7 +325,7 @@ dependencies = [
"quote",
"rustc-hash",
"syn 2.0.38",
"toml_edit",
"toml_edit 0.19.15",
]
[[package]]
@ -595,6 +619,18 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "console"
version = "0.15.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"
dependencies = [
"encode_unicode",
"lazy_static",
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "const-oid"
version = "0.9.5"
@ -854,6 +890,21 @@ dependencies = [
"serde",
]
[[package]]
name = "encode_unicode"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
[[package]]
name = "encoding_rs"
version = "0.8.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59"
dependencies = [
"cfg-if",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -1093,9 +1144,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.0"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652"
checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456"
dependencies = [
"percent-encoding",
]
@ -1291,6 +1342,25 @@ dependencies = [
"smallvec",
]
[[package]]
name = "h2"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http",
"indexmap 2.0.2",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -1527,6 +1597,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
@ -1568,9 +1639,11 @@ name = "idcoop"
version = "0.0.1"
dependencies = [
"argon2",
"assert_matches2",
"async-trait",
"axum",
"axum-client-ip",
"axum-test-helper",
"base64",
"blake2",
"chrono",
@ -1581,11 +1654,16 @@ dependencies = [
"futures",
"governor",
"hornbeam",
"insta",
"josekit",
"maplit",
"metrics",
"metrics-exporter-prometheus",
"metrics-process",
"pgtemp",
"rand",
"rand_xoshiro",
"rstest",
"serde",
"serde_json",
"serde_urlencoded",
@ -1602,9 +1680,9 @@ dependencies = [
[[package]]
name = "idna"
version = "0.4.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
@ -1673,6 +1751,19 @@ dependencies = [
"libc",
]
[[package]]
name = "insta"
version = "1.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "810ae6042d48e2c9e9215043563a58a80b877bc863228a74cf10c49d4620a6f5"
dependencies = [
"console",
"lazy_static",
"linked-hash-map",
"serde",
"similar",
]
[[package]]
name = "instant"
version = "0.1.12"
@ -1851,6 +1942,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
@ -1888,6 +1985,12 @@ dependencies = [
"libc",
]
[[package]]
name = "maplit"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d"
[[package]]
name = "matchers"
version = "0.1.0"
@ -2004,6 +2107,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2285,9 +2398,9 @@ dependencies = [
[[package]]
name = "percent-encoding"
version = "2.3.0"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pest"
@ -2366,6 +2479,18 @@ dependencies = [
"indexmap 2.0.2",
]
[[package]]
name = "pgtemp"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc39977b03503cfcf74ce2f8938d7b8340d1f8ef31a3f2a289b1c1abff8ace2c"
dependencies = [
"libc",
"tempfile",
"tokio",
"url",
]
[[package]]
name = "pin-project"
version = "1.1.3"
@ -2449,6 +2574,15 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro-crate"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284"
dependencies = [
"toml_edit 0.21.1",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
@ -2532,6 +2666,15 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rand_xoshiro"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
dependencies = [
"rand_core",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
@ -2603,6 +2746,51 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3cbb081b9784b07cceb8824c8583f86db4814d172ab043f3c23f7dc600bf83d"
[[package]]
name = "relative-path"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]]
name = "reqwest"
version = "0.11.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"ipnet",
"js-sys",
"log",
"mime",
"mime_guess",
"once_cell",
"percent-encoding",
"pin-project-lite",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-util",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"winreg",
]
[[package]]
name = "rlimit"
version = "0.10.1"
@ -2632,6 +2820,36 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rstest"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9afd55a67069d6e434a95161415f5beeada95a01c7b815508a82dcb0e1593682"
dependencies = [
"futures",
"futures-timer",
"rstest_macros",
"rustc_version",
]
[[package]]
name = "rstest_macros"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4165dfae59a39dd41d8dec720d3cbfbc71f69744efb480a3920f5d4e0cc6798d"
dependencies = [
"cfg-if",
"glob",
"proc-macro-crate",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.38",
"unicode-ident",
]
[[package]]
name = "rustc-demangle"
version = "0.1.23"
@ -2644,6 +2862,15 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.36.15"
@ -2736,6 +2963,12 @@ version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ef965a420fe14fdac7dd018862966a4c14094f900e1650bbc71ddd7d580c8af"
[[package]]
name = "semver"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]]
name = "serde"
version = "1.0.188"
@ -2827,6 +3060,15 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7cee0529a6d40f580e7a5e6c495c8fbfe21b7b52795ed4bb5e62cdf92bc6380"
[[package]]
name = "signal-hook-registry"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1"
dependencies = [
"libc",
]
[[package]]
name = "signature"
version = "2.1.0"
@ -2837,6 +3079,12 @@ dependencies = [
"rand_core",
]
[[package]]
name = "similar"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640"
[[package]]
name = "sketches-ddsketch"
version = "0.2.1"
@ -3210,6 +3458,27 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tempfile"
version = "3.8.1"
@ -3325,7 +3594,10 @@ dependencies = [
"bytes",
"libc",
"mio",
"num_cpus",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2 0.5.4",
"tokio-macros",
"windows-sys 0.48.0",
@ -3353,6 +3625,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "toml"
version = "0.5.11"
@ -3379,6 +3664,17 @@ dependencies = [
"winnow",
]
[[package]]
name = "toml_edit"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1"
dependencies = [
"indexmap 2.0.2",
"toml_datetime",
"winnow",
]
[[package]]
name = "tower"
version = "0.4.13"
@ -3576,6 +3872,15 @@ dependencies = [
"unic-langid-impl",
]
[[package]]
name = "unicase"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89"
dependencies = [
"version_check",
]
[[package]]
name = "unicode-bidi"
version = "0.3.13"
@ -3617,9 +3922,9 @@ checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "url"
version = "2.4.1"
version = "2.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5"
checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c"
dependencies = [
"form_urlencoded",
"idna",
@ -3716,6 +4021,18 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03"
dependencies = [
"cfg-if",
"js-sys",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.87"
@ -3745,6 +4062,19 @@ version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1"
[[package]]
name = "wasm-streams"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "web-sys"
version = "0.3.64"
@ -4025,6 +4355,16 @@ dependencies = [
"memchr",
]
[[package]]
name = "winreg"
version = "0.50.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]]
name = "zeroize"
version = "1.6.0"

View File

@ -27,6 +27,7 @@ metrics = "0.21.1"
metrics-exporter-prometheus = "0.12.1"
metrics-process = "1.0.12"
rand = "0.8.5"
rand_xoshiro = "0.6.0"
serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.108"
serde_urlencoded = "0.7.1"
@ -39,3 +40,12 @@ tower-cookies = "0.9.0"
tower-http = { version = "0.4.4", features = ["trace", "cors", "set-header"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
[dev-dependencies]
assert_matches2 = "0.1.2"
axum-test-helper = "0.3.0"
insta = { version = "1.39.0", features = ["serde", "yaml"] }
maplit = "1.0.2"
pgtemp = "0.3.0"
rand_xoshiro = "0.6.0"
rstest = "0.21.0"

View File

@ -10,4 +10,4 @@
# Development
- [Testing](dev/testing.md)

45
docs/dev/testing.md Normal file
View File

@ -0,0 +1,45 @@
# Testing
## Testing approach
### Unit tests
Unit tests should live inside a module in the code unit they are testing,
gated behind `#[cfg(test)]`. \
This is fairly common Rust practice.
There is no hard and fast rule for the granularity of the unit tests,
but they should test the smallest amount of logic that is simple to test,
but no smaller. \
In practice this means that unit tests should be at the function-level
where this makes sense, or at the struct-level if this makes more sense.
For now, avoid the use of test mocks, but use them if it makes
strong sense to do so.
### Integration tests
Integration tests should live in the `tests/` directory.
In general, each test will get its own throwaway Postgres database.
#### Snapshot tests
Some of the integration tests will compare snapshots (of e.g. HTML) against
a gold standard.
When a new snapshot is created, the output should be manually verified,
including in a browser if necessary.
It goes without saying that all snapshot changes should be expected;
if they are not then treated as failures.
### End-to-end tests
idCoop doesn't currently have end-to-end tests but this is on the wishlist
for the future.
Will eventually look into Playwright etc.

View File

@ -74,6 +74,9 @@
# Test coverage. Vaguely useful but not definitive.
pkgs.cargo-tarpaulin
# Snapshot testing
pkgs.cargo-insta
pkgs.grass-sass
pkgs.entr

View File

@ -1,15 +1,12 @@
use std::io::stdin;
use std::sync::Arc;
use std::{net::SocketAddr, path::PathBuf};
use clap::Parser;
use comfy_table::presets::UTF8_FULL;
use comfy_table::{Attribute, Cell, Color, ContentArrangement, Row, Table};
use confique::{Config, Partial};
use eyre::{bail, Context, ContextCompat};
use eyre::{bail, Context};
use idcoop::cli::{handle_user_command, UserCommand};
use idcoop::config::{SecretConfig, SeparateSecretConfiguration};
use idcoop::passwords::create_password_hash;
use idcoop::store::{CreateUser, IdCoopStore};
use idcoop::store::IdCoopStore;
use idcoop::{config::Configuration, web};
use tracing_subscriber::fmt::format::FmtSpan;
use tracing_subscriber::layer::SubscriberExt;
@ -106,180 +103,15 @@ async fn main() -> eyre::Result<()> {
.await
.context("Failed to connect to Postgres")?;
let secrets = SecretConfig::try_new(&config).await?;
web::serve(bind, Arc::new(store), Arc::new(config), Arc::new(secrets)).await?
web::serve(bind, Arc::new(store), Arc::new(config), Arc::new(secrets)).await?;
}
Subcommand::User { cmd } => handle_user_command(cmd, &config).await?,
}
Ok(())
}
/// Commands for user management.
#[derive(Clone, Parser)]
enum UserCommand {
/// Add a user.
#[clap(alias = "new", alias = "create")]
Add {
/// The login name of the user.
// TODO this should be a richer newtype with validation
username: String,
#[clap(long = "locked")]
locked: bool,
},
/// Deletes a user.
/// Consider whether this is what you really want: in most cases locking a user is more appropriate.
#[clap(alias = "remove", alias = "rm", alias = "del")]
Delete {
/// The login name of the user.
username: String,
},
/// Locks a user, preventing them from logging in.
Lock {
/// The login name of the user.
username: String,
},
/// Unlocks a user, letting them log in once more.
Unlock {
/// The login name of the user.
username: String,
},
/// Changes a user's password.
#[clap(alias = "chpass", alias = "passwd")]
ChangePassword {
/// The login name of the user.
username: String,
},
/// Lists all users that are registered.
#[clap(alias = "ls")]
ListAll {
/// Only show a list of usernames, without table formatting characters and one per line. May be useful in scripts.
#[clap(long = "usernames")]
usernames: bool,
},
}
async fn handle_user_command(command: UserCommand, config: &Configuration) -> eyre::Result<()> {
let store = IdCoopStore::connect(&config.postgres.connect)
.await
.context("Failed to connect to Postgres")?;
match command {
UserCommand::Add { username, locked } => {
store
.txn(|mut txn| {
Box::pin(async move {
txn.create_user(CreateUser {
user_login_name: username,
password_hash: None,
locked,
})
.await
})
})
Subcommand::User { cmd } => {
let store = IdCoopStore::connect(&config.postgres.connect)
.await
.context("failed to add user")?;
}
UserCommand::Delete { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.delete_user(user_id).await
})
})
.await?;
}
UserCommand::Lock { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.set_user_locked(user_id, true).await
})
})
.await?;
}
UserCommand::Unlock { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.set_user_locked(user_id, false).await
})
})
.await?;
}
UserCommand::ChangePassword { username } => {
let Some(user) = store.txn(|mut txn| { Box::pin(async move {
txn.lookup_user_by_name(username).await
})}).await? else {
bail!("No user by that name.");
};
println!("Change password for {} ({}):", user.user_name, user.user_id);
let mut buf_line = String::new();
stdin()
.read_line(&mut buf_line)
.context("failed to read password")?;
let raw_password = buf_line.trim();
let hash = create_password_hash(raw_password, &config.password_hashing)
.context("unable to hash password!")?;
store
.txn(|mut txn| {
Box::pin(
async move { txn.change_user_password(user.user_id, Some(hash)).await },
)
})
.await?;
}
UserCommand::ListAll { usernames } => {
let user_infos = store
.txn(|mut txn| Box::pin(async move { txn.list_user_info().await }))
.await?;
if usernames {
for user_info in user_infos {
println!("{}", user_info.user_name);
}
} else {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.set_content_arrangement(ContentArrangement::Dynamic)
.set_width(80)
.set_header(vec![
Cell::new("Name").add_attribute(Attribute::Bold),
Cell::new("UUID").add_attribute(Attribute::Bold),
Cell::new("Locked").add_attribute(Attribute::Bold),
]);
for user_info in user_infos {
let mut row = Row::new();
row.add_cell(Cell::new(user_info.user_name));
row.add_cell(Cell::new(user_info.user_id).fg(Color::Grey));
let (lock_str, lock_colour) = if user_info.locked {
("yes", Color::Red)
} else {
("no", Color::White)
};
row.add_cell(Cell::new(lock_str).fg(lock_colour));
table.add_row(row);
}
println!("{}", table);
}
.context("Failed to connect to Postgres")?;
handle_user_command(cmd, &config, &store).await?;
}
}
Ok(())
}

185
src/cli.rs Normal file
View File

@ -0,0 +1,185 @@
//! idCoop Command Line Interface
use std::io::stdin;
use crate::config::Configuration;
use crate::passwords::create_password_hash;
use crate::store::{CreateUser, IdCoopStore};
use clap::Parser;
use comfy_table::presets::UTF8_FULL;
use comfy_table::{Attribute, Cell, Color, ContentArrangement, Row, Table};
use eyre::{bail, Context, ContextCompat};
/// Commands for user management.
#[derive(Clone, Parser)]
pub enum UserCommand {
/// Add a user.
#[clap(alias = "new", alias = "create")]
Add {
/// The login name of the user.
// TODO this should be a richer newtype with validation
username: String,
/// Set this flag if the user should be locked.
#[clap(long = "locked")]
locked: bool,
},
/// Deletes a user.
/// Consider whether this is what you really want: in most cases locking a user is more appropriate.
#[clap(alias = "remove", alias = "rm", alias = "del")]
Delete {
/// The login name of the user.
username: String,
},
/// Locks a user, preventing them from logging in.
Lock {
/// The login name of the user.
username: String,
},
/// Unlocks a user, letting them log in once more.
Unlock {
/// The login name of the user.
username: String,
},
/// Changes a user's password.
#[clap(alias = "chpass", alias = "passwd")]
ChangePassword {
/// The login name of the user.
username: String,
},
/// Lists all users that are registered.
#[clap(alias = "ls")]
ListAll {
/// Only show a list of usernames, without table formatting characters and one per line. May be useful in scripts.
#[clap(long = "usernames")]
usernames: bool,
},
}
/// Handles a user command from the command-line interface.
pub async fn handle_user_command(
command: UserCommand,
config: &Configuration,
store: &IdCoopStore,
) -> eyre::Result<()> {
match command {
UserCommand::Add { username, locked } => {
store
.txn(|mut txn| {
Box::pin(async move {
txn.create_user(CreateUser {
user_login_name: username,
password_hash: None,
locked,
})
.await
})
})
.await
.context("failed to add user")?;
}
UserCommand::Delete { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.delete_user(user_id).await
})
})
.await?;
}
UserCommand::Lock { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.set_user_locked(user_id, true).await
})
})
.await?;
}
UserCommand::Unlock { username } => {
store
.txn(|mut txn| {
Box::pin(async move {
let user_id = txn
.lookup_user_by_name(username)
.await?
.context("No user by that name")?
.user_id;
txn.set_user_locked(user_id, false).await
})
})
.await?;
}
UserCommand::ChangePassword { username } => {
let Some(user) = store
.txn(|mut txn| Box::pin(async move { txn.lookup_user_by_name(username).await }))
.await?
else {
bail!("No user by that name.");
};
println!("Change password for {} ({}):", user.user_name, user.user_id);
let mut buf_line = String::new();
stdin()
.read_line(&mut buf_line)
.context("failed to read password")?;
let raw_password = buf_line.trim();
let hash = create_password_hash(raw_password, &config.password_hashing)
.context("unable to hash password!")?;
store
.txn(|mut txn| {
Box::pin(
async move { txn.change_user_password(user.user_id, Some(hash)).await },
)
})
.await?;
}
UserCommand::ListAll { usernames } => {
let user_infos = store
.txn(|mut txn| Box::pin(async move { txn.list_user_info().await }))
.await?;
if usernames {
for user_info in user_infos {
println!("{}", user_info.user_name);
}
} else {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.set_content_arrangement(ContentArrangement::Dynamic)
.set_width(80)
.set_header(vec![
Cell::new("Name").add_attribute(Attribute::Bold),
Cell::new("UUID").add_attribute(Attribute::Bold),
Cell::new("Locked").add_attribute(Attribute::Bold),
]);
for user_info in user_infos {
let mut row = Row::new();
row.add_cell(Cell::new(user_info.user_name));
row.add_cell(Cell::new(user_info.user_id).fg(Color::Grey));
let (lock_str, lock_colour) = if user_info.locked {
("yes", Color::Red)
} else {
("no", Color::White)
};
row.add_cell(Cell::new(lock_str).fg(lock_colour));
table.add_row(row);
}
println!("{}", table);
}
}
}
Ok(())
}

View File

@ -184,6 +184,7 @@ pub struct RatelimitsConfig {
/// - "5 Hz, 20 burst"
/// - "5 per second, 10 burst"
/// - "10 per hour, 5 burst"
#[derive(Debug)]
pub struct RatelimiterConfig {
/// The inner [`Quota`], which this struct is just a wrapper for.
pub quota: Quota,
@ -196,13 +197,19 @@ impl<'de> Deserialize<'de> for RatelimiterConfig {
{
let stringy_format = String::deserialize(deserializer)?;
let Some((left, right)) = stringy_format.split_once(',') else {
return Err(D::Error::custom("no comma. ratelimiter string should be like '5 Hz, 20 burst'"));
return Err(D::Error::custom(
"no comma. ratelimiter string should be like '5 Hz, 20 burst'",
));
};
let Some((left_val, left_unit)) = left.trim().split_once(' ') else {
return Err(D::Error::custom("no units on left. ratelimiter string should be like '5 Hz, 20 burst'"));
return Err(D::Error::custom(
"no units on left. ratelimiter string should be like '5 Hz, 20 burst'",
));
};
let Some((right_val, right_unit)) = right.trim().split_once(' ') else {
return Err(D::Error::custom("no units on right. ratelimiter string should be like '5 Hz, 20 burst'"));
return Err(D::Error::custom(
"no units on right. ratelimiter string should be like '5 Hz, 20 burst'",
));
};
let Ok(left_val) = left_val.parse::<NonZeroU32>() else {
@ -237,3 +244,62 @@ impl<'de> Deserialize<'de> for RatelimiterConfig {
Ok(RatelimiterConfig { quota })
}
}
#[cfg(test)]
mod test {
use crate::config::RatelimiterConfig;
#[test]
fn test_ratelimiter_deser_errors() {
fn deser(s: &str) -> Result<RatelimiterConfig, serde_json::Error> {
serde_json::from_value(serde_json::Value::String(s.to_owned()))
}
// this is fine
deser("5 Hz, 20 burst").unwrap();
// no comma
assert_eq!(
deser("5 Hz 20 burst").unwrap_err().to_string(),
"no comma. ratelimiter string should be like '5 Hz, 20 burst'"
);
// bad numbers
assert_eq!(
deser("five Hz, 20 burst").unwrap_err().to_string(),
"bad value on left. ratelimiter string should be like '5 Hz, 20 burst'"
);
assert_eq!(
deser("5 Hz, twenty burst").unwrap_err().to_string(),
"bad value on right. ratelimiter string should be like '5 Hz, 20 burst'"
);
// wrong order
assert_eq!(
deser("20 burst, 5 Hz").unwrap_err().to_string(),
"bad units on left. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'."
);
// no units
assert_eq!(
deser("5 Hz, 20").unwrap_err().to_string(),
"no units on right. ratelimiter string should be like '5 Hz, 20 burst'"
);
assert_eq!(
deser("5, 20 burst").unwrap_err().to_string(),
"no units on left. ratelimiter string should be like '5 Hz, 20 burst'"
);
// bad units
assert_eq!(
deser("20 per milleniumm, 20 burst")
.unwrap_err()
.to_string(),
"bad units on left. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'."
);
assert_eq!(
deser("5 Hz, 20 wombats").unwrap_err().to_string(),
"bad units on right. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'."
);
}
}

View File

@ -5,7 +5,12 @@
#![deny(missing_docs)]
pub mod cli;
pub mod config;
pub mod passwords;
pub mod store;
pub mod utils;
pub mod web;
#[cfg(test)]
mod tests;

View File

@ -52,3 +52,43 @@ pub fn check_hash(password: &str, hash: &str) -> eyre::Result<bool> {
Ok(argon2.verify_password(password.as_bytes(), &hash).is_ok())
}
#[cfg(test)]
mod test {
use rstest::{fixture, rstest};
use crate::{config::PasswordHashingConfig, passwords::check_hash};
use super::create_password_hash;
/// Password hash for "secret"
const EXAMPLE_SECRET_PASSWORD_HASH: &str = "$argon2id$v=19$m=512,t=1,p=1$Z11PjkMSx/rm4IbDzDmK2Q$VtUH6Iee/GD1FltULyLf6/QRwjNA9d5+mjAKS+WzlZw";
/// Password hash for "verysecret"
const EXAMPLE_VERYSECRET_PASSWORD_HASH: &str = "$argon2id$v=19$m=512,t=1,p=1$tcGnJC8EOR4B43z35roeFg$o3grpVK9HAb3850iRI1/nXR+nPZbyFWchmkbkBhm7Co";
#[fixture]
fn password_config() -> PasswordHashingConfig {
PasswordHashingConfig {
memory: 512,
iterations: 1,
parallelism: 1,
}
}
#[rstest]
fn test_valid_password(password_config: PasswordHashingConfig) {
let pwh = create_password_hash("secret", &password_config).unwrap();
assert!(check_hash("secret", &pwh).unwrap());
assert!(check_hash("secret", EXAMPLE_SECRET_PASSWORD_HASH).unwrap());
}
#[rstest]
fn test_invalid_password(password_config: PasswordHashingConfig) {
let pwh = create_password_hash("verysecret", &password_config).unwrap();
assert!(!check_hash("secret", &pwh).unwrap());
assert!(!check_hash("secret", EXAMPLE_VERYSECRET_PASSWORD_HASH).unwrap());
}
}

View File

@ -110,7 +110,7 @@ pub struct UserInfo {
/// A wrapper around a database transaction with some database methods on it.
pub struct IdCoopStoreTxn<'a, 'txn> {
txn: &'a mut Transaction<'txn, Postgres>,
pub(crate) txn: &'a mut Transaction<'txn, Postgres>,
}
impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
.await
.context("failed to lookup login session")?;
let Some(row) = row_opt else { return Ok(None); };
let Some(row) = row_opt else {
return Ok(None);
};
Ok(Some(LoginSession {
user_name: row.user_name,
@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> {
.await
.context("failed to lookup application session")?;
let Some(row) = row_opt else { return Ok(None); };
let Some(row) = row_opt else {
return Ok(None);
};
Ok(Some(ApplicationSession {
application_session_id: row.session_id,

112
src/tests.rs Normal file
View File

@ -0,0 +1,112 @@
use std::sync::Arc;
use axum::Router;
use confique::{Config, Partial};
use josekit::jwk::alg::rsa::RsaKeyPair;
use metrics::atomics::AtomicU64;
use pgtemp::PgTempDB;
use serde_json::json;
use crate::{
config::{Configuration, SecretConfig},
store::IdCoopStore,
utils::{Clock, RandGen},
web::make_router,
};
struct TestSystem {
database: PgTempDB,
web: Router,
config: Arc<Configuration>,
store: Arc<IdCoopStore>,
clock: Clock,
}
const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem");
const RSA_PUBLIC_KEY_PEM: &[u8] = include_bytes!("tests/publickey.crt");
mod test_cli;
mod test_oidc_auth_flow;
async fn basic_system() -> TestSystem {
let temp_db = pgtemp::PgTempDBBuilder::new()
.with_dbname("test_idcoop")
.start_async()
.await;
let store = IdCoopStore::connect(&temp_db.connection_uri())
.await
.expect("failed to connect to pgtemp db");
let config_partial: <Configuration as confique::Config>::Partial =
serde_json::from_value(json!({
"listen": {
// Not useful, not actually used in the tests
"bind": "127.0.0.1:1",
"public_base_uri": "http://idcoop.example.com",
"client_ip_source": "RightmostXForwardedFor",
},
"postgres": {
"connect": "postgres://not-used-in-tests"
},
"oidc": {
"issuer": "http://issuer.example.com",
"rsa_keypair": "not-used-in-tests",
"clients": {
"aclient": {
"redirect_uris": [
"http://aclient.example.com/redirect"
],
"name": "AClient",
"allow_user_classes": ["active"],
"secret": "secretA",
}
}
},
"password_hashing": {
// Use weak password hash settings; we're not testing Argon2 here,
// we just want it to be fast.
"memory": 512,
"iterations": 1,
},
"ratelimits": {
"login": "3 per hour, 2 burst",
},
}))
.expect("bad test config");
let config =
Configuration::from_partial(config_partial.with_fallback(Partial::default_values()))
.expect("failed to load builtin config");
let secrets = SecretConfig {
rsa_key_pair: RsaKeyPair::from_pem(RSA_KEY_PAIR_PEM)
.expect("failed to decode builtin RSA keypair"),
};
let config = Arc::new(config);
let store = Arc::new(store);
let clock = Clock(Arc::new(AtomicU64::new(0)));
let randgen = RandGen::new();
let router = make_router(
store.clone(),
config.clone(),
Arc::new(secrets),
clock.clone(),
randgen,
)
.await
.expect("failed to make router");
TestSystem {
database: temp_db,
web: router,
config,
store,
clock,
}
}
#[tokio::test]
async fn test_demo() {
let basic = basic_system().await;
}

28
src/tests/keypair.pem Normal file
View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDDu6acOa+3ae2S
0llp5oMsXjBMd5QJeQJCcY5Q9NAITF2U9VBwAiMf2wmaTZ1aWWFGSb/zWef7Hx1e
qNhsK9MYL+QdJih2I+KpMtDWm7hhy9FtCHVc9i1z9PruXb0om2jDWLuBkPdCqJZT
C58ObZKmgL4OH5F1Qv5JR/ZX21OjLolXPJo1sonLv9mlgufvhUmnC17onSSqFLBA
nhUedjbfdnLShkp0xa8G0nW7Ls7Idaxyo8M5S2M+azJyLI87eqjjfz0yIW0Am890
mRa81hO2D+YNcVA2wIE9MEI/ie480YxLQ0VHCjX4DcVir4ceExysYkdL8VK+U14g
NO67k4NVAgMBAAECggEABnseJyoZ0V7miOgCIemKClwMCVwkQLQLCRwtdCzG/p9Y
sef1g9/uPc3I4Z0USruO5v7mJi6h6cS7+jhpAhvpX3GmgfiTemXxyVxvYcvCLSrM
gmm3SR61npNMA7yC2OdcbqtvefjM1x4x7AoEeDvUkULOCDWvYUyYkuCZHYubl1mS
Rtcp9rxzky2tjdp8CHySBa9Kz9LEjWdFGky7g3vSyqZtw6tkK5CTwMPb9aHwiEq1
yDWCbqAPPnb300dXSqx7z3AcsxBi/lpCs79fQS1vSJ1/L9POpYxX0SMtffkR5+Bl
Mkg9dZUVenfbP4n40FdMypTTX2KJiMkc8+f0+Tz9QQKBgQDrdT7xs0lqMeZWp2rD
y2KLzRl0iO/+yKse+BgkeBggZ/Vh2TeF9ylTRYdmimxkJ8eZRipTr0F64BS/LPEk
RgWviuf9dUl0gyuhYTOJgUw1wcgtB6e+04UKEUsQW6JoNZekkM+xTa7FskLmlJ4J
zzowjF1lgJeEyX1tvWXKIVe+/QKBgQDUzy5nZoUqBrVTYxL2uadIUh8oiBKPU93U
Gz3DUq90yfDa7lFhwMRQRXfNqGUy6tshsaF4fT1b62hZDSz1OH3h/y1LKQOdF5kc
JJyk/4b7NJna16kwBLzWje5SjQKr51aQWU8JftZ5/8uck2j7vMi+mgwzpG45J7kv
Q1I5decBOQKBgFq1sKotB/uBfdukY91KXYy+VzAuEUd2x3YG3kYufhz97+riZCGY
NrN99cvrSBbNvHewMF5NBkzwRw3foob28vnN6dIbfVEFt6lUaSZwSYvsO9IdQOKj
Wn2ma+TBaK/89Y7QuzLzWoGPS3bJipj83M4XRWP1RmpBtbCxZqWYctWBAoGAMZPi
16wGsffGHpsiO+CcnDilkafByypar6N5DBwjTC4PsrF6vC9QjPLiKkNk8CvOyVa8
q3lh5hw9vyFWq/pxOUldn/j6Iorw3KGa7MWrCLMEdPtxKwKvi7ydHRZE3Q+UFyT3
SNsH1HxHTz74Yk1k5yK0XQOduisK9XvVmBVjr+ECgYEAyoSbo/1cyLKWgrIr0K/f
stiKL9SmBmYbaGaxtQToB5Hnqso7Hz5YEDlrcr8s1ukEFghgeNYuDYw3ZKKGGfZm
yVQKAt8ouoO8rfkLrtt0H+/0uJgouhewDEqf/O+MfzwDnFcT89J5ZTEf+9n6pjry
fuiQnuwEsPYGCCFuWWlrdHQ=
-----END PRIVATE KEY-----

9
src/tests/publickey.crt Normal file
View File

@ -0,0 +1,9 @@
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAw7umnDmvt2ntktJZaeaD
LF4wTHeUCXkCQnGOUPTQCExdlPVQcAIjH9sJmk2dWllhRkm/81nn+x8dXqjYbCvT
GC/kHSYodiPiqTLQ1pu4YcvRbQh1XPYtc/T67l29KJtow1i7gZD3QqiWUwufDm2S
poC+Dh+RdUL+SUf2V9tToy6JVzyaNbKJy7/ZpYLn74VJpwte6J0kqhSwQJ4VHnY2
33Zy0oZKdMWvBtJ1uy7OyHWscqPDOUtjPmsyciyPO3qo4389MiFtAJvPdJkWvNYT
tg/mDXFQNsCBPTBCP4nuPNGMS0NFRwo1+A3FYq+HHhMcrGJHS/FSvlNeIDTuu5OD
VQIDAQAB
-----END PUBLIC KEY-----

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "16"
content-type: text/plain; charset=utf-8
- No access token.

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "21"
content-type: text/plain; charset=utf-8
- Invalid access token.

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "238"
content-type: text/html; charset=utf-8
set-cookie: __Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
x-frame-options: DENY
- "<form method='POST'>UN<input type='text' name='username'> PW<input type='password' name='password'> <input type='hidden' name='xsrf' value='HL4qRFKUlBqkrPTvAQ6z-w'><button type='submit'>click here to login</button> (temporary form)</form>"

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "28"
content-type: text/plain; charset=utf-8
- Invalid application session.

View File

@ -0,0 +1,10 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "55"
content-type: text/plain; charset=utf-8
location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey"
set-cookie: __Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000
x-frame-options: DENY
- Logged in. Redirecting you back to what you were doing.

View File

@ -0,0 +1,10 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "75"
content-type: application/json
- error: invalid_request
error_description: "`code` parameter missing."

View File

@ -0,0 +1,8 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "288"
content-type: text/html; charset=utf-8
x-frame-options: DENY
- "hi <u>robert</u>, consent to <u>AClient</u>? <form method='POST'><input type='hidden' name='xsrf' value='0.JpKyqkWckzF6w6btxX2RXv_MlxgOfoYOZJknydValkc'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>"

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "124"
content-type: application/json
- "{\"error\":\"invalid_grant\",\"error_description\":\"Auth code has been redeemed multiple times! This could mean something nasty.\"}"

View File

@ -0,0 +1,10 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "77"
content-type: application/json
- error: invalid_request
error_description: "`code` parameter malformed."

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- content-length: "46"
content-type: text/plain; charset=utf-8
location: "http://aclient.example.com/redirect?code=UnLS_bGq0ZB4szozTRCJIG-37ibG08zK&state=wombat&iss=http%3A%2F%2Fissuer.example.com"
x-frame-options: DENY
- Authorisation succeeded; redirecting you back.

View File

@ -0,0 +1,10 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "84"
content-type: application/json
- error: invalid_request
error_description: "`code_verifier` parameter missing."

View File

@ -0,0 +1,14 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "803"
content-type: application/json
- access_token: pvgYf08qA_ctEIhMP4DFQzbxjiCx8qfgi4cATwGsH9Q
expires_in: 31536000
id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg
refresh_token: _AHNodjMMCQJw2Bq3bsvDSZnjI4FEB2DDMK8ZHgDej8
scope: openid
token_type: Bearer

View File

@ -0,0 +1,10 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "74"
content-type: application/json
- error: invalid_grant
error_description: Code challenge is invalid.

View File

@ -0,0 +1,11 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, json)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "92"
content-type: application/json
- name: robert
preferred_username: robert
sub: 00000000-0000-0000-0000-000000000000

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "505"
content-type: application/json
- "{\"issuer\":\"http://idcoop.example.com\",\"authorization_endpoint\":\"http://idcoop.example.com/oidc/auth\",\"token_endpoint\":\"http://idcoop.example.com/oidc/token\",\"userinfo_endpoint\":\"http://idcoop.example.com/oidc/userinfo\",\"jwks_uri\":\"http://idcoop.example.com/oidc/jwks\",\"scopes_supported\":[\"openid\"],\"response_types_supported\":[\"code\"],\"response_modes_supported\":[\"query\"],\"grant_types_supported\":[\"authorization_code\"],\"subject_types_supported\":[\"public\"],\"id_token_signing_alg_values_supported\":[\"RS256\"]}"

View File

@ -0,0 +1,9 @@
---
source: src/tests/test_oidc_auth_flow.rs
expression: "(headers, text)"
---
- access-control-allow-origin: "*"
access-control-expose-headers: "*"
content-length: "425"
content-type: application/json
- "{\"keys\":[{\"kty\":\"RSA\",\"n\":\"w7umnDmvt2ntktJZaeaDLF4wTHeUCXkCQnGOUPTQCExdlPVQcAIjH9sJmk2dWllhRkm_81nn-x8dXqjYbCvTGC_kHSYodiPiqTLQ1pu4YcvRbQh1XPYtc_T67l29KJtow1i7gZD3QqiWUwufDm2SpoC-Dh-RdUL-SUf2V9tToy6JVzyaNbKJy7_ZpYLn74VJpwte6J0kqhSwQJ4VHnY233Zy0oZKdMWvBtJ1uy7OyHWscqPDOUtjPmsyciyPO3qo4389MiFtAJvPdJkWvNYTtg_mDXFQNsCBPTBCP4nuPNGMS0NFRwo1-A3FYq-HHhMcrGJHS_FSvlNeIDTuu5ODVQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"thekey\",\"alg\":\"RS256\"}]}"

164
src/tests/test_cli.rs Normal file
View File

@ -0,0 +1,164 @@
use rstest::rstest;
use crate::cli::{handle_user_command, UserCommand};
use super::basic_system;
#[rstest]
#[tokio::test]
async fn test_cli_add_user() {
let sys = basic_system().await;
handle_user_command(
UserCommand::Add {
username: "jonathan".to_owned(),
locked: true,
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(user.unwrap().locked);
Ok(())
})
})
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_cli_lock_and_unlock_user() {
let sys = basic_system().await;
handle_user_command(
UserCommand::Add {
username: "jonathan".to_owned(),
locked: false,
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(!user.unwrap().locked);
Ok(())
})
})
.await
.unwrap();
handle_user_command(
UserCommand::Lock {
username: "jonathan".to_owned(),
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(user.unwrap().locked);
Ok(())
})
})
.await
.unwrap();
handle_user_command(
UserCommand::Unlock {
username: "jonathan".to_owned(),
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(!user.unwrap().locked);
Ok(())
})
})
.await
.unwrap();
}
#[rstest]
#[tokio::test]
async fn test_cli_del_user() {
let sys = basic_system().await;
handle_user_command(
UserCommand::Add {
username: "jonathan".to_owned(),
locked: true,
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(user.unwrap().locked);
Ok(())
})
})
.await
.unwrap();
handle_user_command(
UserCommand::Delete {
username: "jonathan".to_owned(),
},
&sys.config,
&sys.store,
)
.await
.unwrap();
let _: () = sys
.store
.txn(|mut txn| {
Box::pin(async move {
let user = txn.lookup_user_by_name("jonathan".to_owned()).await?;
assert!(user.is_none());
Ok(())
})
})
.await
.unwrap();
}

View File

@ -0,0 +1,472 @@
//! Tests the OpenID Connect auth flow
use std::collections::BTreeMap;
use axum::http::StatusCode;
use axum_test_helper::{TestClient, TestResponse};
use insta::assert_yaml_snapshot;
use maplit::btreemap;
use sqlx::types::Uuid;
use crate::{passwords::create_password_hash, tests::basic_system};
async fn dump_resp_text(
req_name: &str,
resp: TestResponse,
) -> (StatusCode, BTreeMap<String, String>, String) {
let status = resp.status();
// convert headers to a simple B-Tree map so they can be serialised in snapshots
// easily
let mut headers: BTreeMap<String, String> = resp
.headers()
.clone()
.into_iter()
// skip headers that are duplicate keys.
// We don't care about theme for our purposes.
.filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned())))
.collect();
// Remove date because it's not stable across tests!
headers.remove("date");
// Remove vary because it has multiple values and we don't want to
// introduce instability into our tests by only allowing one through.
headers.remove("vary");
let text = resp.text().await;
eprintln!("=== Response for {req_name} ===");
eprintln!("Status: {status:?}");
eprintln!("Headers: {headers:#?}");
eprintln!("Body: {text:?}");
eprintln!("=== End of response ===");
(status, headers, text)
}
/// Tests the full flow...
#[tokio::test]
async fn test_full_flow() {
let sys = basic_system().await;
let uuid = Uuid::nil();
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
let _: () = sys
.store
.txn(|txn| {
Box::pin(async move {
sqlx::query(
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
).bind("robert").bind(uuid).bind(pwhash).bind(false)
.fetch_one(&mut **txn.txn)
.await.unwrap();
Ok(())
})
})
.await
.unwrap();
let client = TestClient::new(sys.web);
///// These requests are on behalf of the user's browser /////
const CODE_VERIFIER: &str = "verifying";
// base64(sha256("verifying"))
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
// 1. /auth request
let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey");
let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await;
let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await;
assert_eq!(status, 302);
assert_eq!(headers.get("location").unwrap(), &login_url);
// 2. /login request
let resp = client.get(&login_url).send().await;
let (status, headers, text) = dump_resp_text("2. /login request", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!("2/login", (headers, text));
// 3. /login request with credentials
let resp = client
.post(&login_url)
.form(&btreemap! {
"username" => "robert",
"password" => "secret",
"xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w",
})
.header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w")
// /login is rate-limited by IP source and needs an IP
.header("X-Forwarded-For", "0.0.0.0")
.send()
.await;
let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await;
assert_eq!(status, 302);
let auth_loc = headers.get("location").unwrap().to_owned();
assert_yaml_snapshot!("3/login", (headers, text));
// 4. /auth request
let resp = client
.get(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY",
)
.send()
.await;
let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!("4/auth", (headers, text));
sys.clock.set_time(30);
// 5. /auth request with confirmation
let resp = client
.post(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY",
)
.form(&btreemap! {
"action" => "accept",
"xsrf" => "0.JpKyqkWckzF6w6btxX2RXv_MlxgOfoYOZJknydValkc",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await;
assert_eq!(status, 302);
// Note this snapshot includes a Location: header back to the client application
assert_yaml_snapshot!("5/auth", (headers, text));
///// At this point, we make requests on behalf of the client /////
// 6. /token request
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "UnLS_bGq0ZB4szozTRCJIG-37ibG08zK",
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("6. POST /token", resp).await;
assert_eq!(status, 200);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("6/token", (headers, json));
// 7. /userinfo request
let resp = client
.get("/oidc/userinfo")
.header(
"Authorization",
"Bearer pvgYf08qA_ctEIhMP4DFQzbxjiCx8qfgi4cATwGsH9Q",
)
.send()
.await;
let (status, headers, text) = dump_resp_text("7. /userinfo", resp).await;
assert_eq!(status, 200);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("7/userinfo", (headers, json));
}
#[tokio::test]
async fn test_jwks_endpoint() {
let sys = basic_system().await;
let client = TestClient::new(sys.web);
let resp = client.get("/oidc/jwks").send().await;
let (status, headers, text) = dump_resp_text("/jwks", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!((headers, text));
}
#[tokio::test]
async fn test_discovery_endpoint() {
let sys = basic_system().await;
let client = TestClient::new(sys.web);
let resp = client.get("/.well-known/openid-configuration").send().await;
let (status, headers, text) = dump_resp_text("discovery", resp).await;
assert_eq!(status, 200);
assert_yaml_snapshot!((headers, text));
}
#[tokio::test]
async fn test_userinfo_bad_auth() {
let sys = basic_system().await;
let client = TestClient::new(sys.web);
// 1. no auth token
let resp = client.get("/oidc/userinfo").send().await;
let (status, headers, text) = dump_resp_text("1. no auth token", resp).await;
assert_eq!(status, 401);
assert_yaml_snapshot!("1. no auth token", (headers, text));
// 2. malformed access token
let resp = client
.get("/oidc/userinfo")
.header("Authorization", "Bearer ++++")
.send()
.await;
let (status, headers, text) = dump_resp_text("2. malformed auth token", resp).await;
assert_eq!(status, 401);
assert_yaml_snapshot!("2. malformed auth token", (headers, text));
// 3. wrong access token
let resp = client
.get("/oidc/userinfo")
.header("Authorization", "Bearer aaaa")
.send()
.await;
let (status, headers, text) = dump_resp_text("3. wrong auth token", resp).await;
assert_eq!(status, 401);
assert_yaml_snapshot!("3. wrong auth token", (headers, text));
}
/// Tests error conditions in the /token endpoint
#[tokio::test]
async fn test_token_errors() {
let sys = basic_system().await;
let uuid = Uuid::nil();
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
let _: () = sys
.store
.txn(|txn| {
Box::pin(async move {
sqlx::query(
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
).bind("robert").bind(uuid).bind(pwhash).bind(false)
.fetch_one(&mut **txn.txn)
.await.unwrap();
Ok(())
})
})
.await
.unwrap();
let client = TestClient::new(sys.web);
///// These requests are on behalf of the user's browser /////
const CODE_VERIFIER: &str = "verifying";
// base64(sha256("verifying"))
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
// 1. /login request with credentials
let resp = client
.post("/login")
.form(&btreemap! {
"username" => "robert",
"password" => "secret",
"xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w",
})
.header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w")
// /login is rate-limited by IP source and needs an IP
.header("X-Forwarded-For", "0.0.0.0")
.send()
.await;
let (status, _headers, _text) =
dump_resp_text("1. /login request with credentials", resp).await;
assert_eq!(status, 302);
let auth_loc = format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey");
// 2. /auth request with confirmation
let resp = client
.post(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
)
.form(&btreemap! {
"action" => "accept",
"xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y",
})
.send()
.await;
let (status, _headers, _text) = dump_resp_text("2. POST /auth after confirmation", resp).await;
assert_eq!(status, 302);
eprintln!("{:?}", _text);
///// At this point, we make requests on behalf of the client /////
// 3. /token request with no code
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("3. /token no code", resp).await;
assert_eq!(status, 400);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("3/token_no_code", (headers, json));
// 4. /token request with malformed code (not long enough)
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "aaaa",
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("4. /token malformed code", resp).await;
assert_eq!(status, 400);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("4/token_malformed_code", (headers, json));
// 5. /token request with no code_verifier
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ",
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("5. /token no verifier", resp).await;
assert_eq!(status, 400);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("5/token_no_verifier", (headers, json));
// 6. /token request with wrong code_verifier
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ",
"code_verifier" => "i'm wrong",
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("6. /token wrong verifier", resp).await;
assert_eq!(status, 400);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
assert_yaml_snapshot!("6/token_wrong_verifier", (headers, json));
}
/// Tests double-requesting a /token (auth code conflict)
#[tokio::test]
async fn test_token_conflict() {
let sys = basic_system().await;
let uuid = Uuid::nil();
let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap();
let _: () = sys
.store
.txn(|txn| {
Box::pin(async move {
sqlx::query(
"INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id",
).bind("robert").bind(uuid).bind(pwhash).bind(false)
.fetch_one(&mut **txn.txn)
.await.unwrap();
Ok(())
})
})
.await
.unwrap();
let client = TestClient::new(sys.web);
///// These requests are on behalf of the user's browser /////
const CODE_VERIFIER: &str = "verifying";
// base64(sha256("verifying"))
const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M";
// 1. /login request with credentials
let resp = client
.post("/login")
.form(&btreemap! {
"username" => "robert",
"password" => "secret",
"xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w",
})
.header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w")
// /login is rate-limited by IP source and needs an IP
.header("X-Forwarded-For", "0.0.0.0")
.send()
.await;
let (status, _headers, _text) =
dump_resp_text("1. /login request with credentials", resp).await;
assert_eq!(status, 302);
let auth_loc = format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey");
// 2. /auth request with confirmation
let resp = client
.post(&auth_loc)
.header(
"Cookie",
"__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE",
)
.form(&btreemap! {
"action" => "accept",
"xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y",
})
.send()
.await;
let (status, _headers, _text) = dump_resp_text("2. POST /auth after confirmation", resp).await;
assert_eq!(status, 302);
eprintln!("{:?}", _text);
///// At this point, we make requests on behalf of the client /////
// 3. /token request (successful)
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ",
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, _headers, text) = dump_resp_text("3. POST /token", resp).await;
assert_eq!(status, 200);
let json: BTreeMap<String, serde_json::Value> = serde_json::from_str(&text).unwrap();
let access_token = json
.get("access_token")
.unwrap()
.as_str()
.unwrap()
.to_owned();
// 4. /token request (conflicting)
let resp = client
.post("/oidc/token")
.header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB")
.form(&btreemap! {
"code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ",
"code_verifier" => CODE_VERIFIER,
"grant_type" => "authorization_code",
"redirect_uri" => "http://aclient.example.com/redirect",
})
.send()
.await;
let (status, headers, text) = dump_resp_text("4. POST /token (conflict)", resp).await;
assert_eq!(status, 400);
assert_yaml_snapshot!("4/token_conflict", (headers, text));
// 5. /userinfo (using the access token that should now have expired)
let resp = client
.get("/oidc/userinfo")
.header("Authorization", format!("Bearer {access_token}"))
.send()
.await;
let (status, _headers, _text) = dump_resp_text("7. /userinfo", resp).await;
assert_eq!(status, 401);
}

145
src/utils.rs Normal file
View File

@ -0,0 +1,145 @@
//! Miscellaneous utilities
#[cfg(not(test))]
pub use self::real_utils::{Clock, RandGen};
#[cfg(test)]
pub use self::test_utils::{Clock, RandGen};
#[cfg(not(test))]
mod real_utils {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use chrono::{DateTime, Utc};
use rand::{thread_rng, RngCore};
/// A source of random numbers that can be faked for tests.
#[derive(Clone)]
pub struct RandGen;
impl RngCore for RandGen {
fn next_u32(&mut self) -> u32 {
thread_rng().next_u32()
}
fn next_u64(&mut self) -> u64 {
thread_rng().next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
thread_rng().fill_bytes(dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
thread_rng().try_fill_bytes(dest)
}
}
/// A source of time that can be faked for tests.
#[derive(Clone)]
pub struct Clock;
impl Clock {
/// Returns the current time as a `DateTime<Utc>`.
pub fn now_utc(&self) -> DateTime<Utc> {
Utc::now()
}
/// Returns the current time as a i64 of seconds since the Unix epoch.
pub fn now_timestamp(&self) -> i64 {
Utc::now().timestamp()
}
/// Sleep until the given timestamp.
pub async fn sleep_until(&self, until_ts: i64) {
if until_ts < 0 {
return;
}
let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64);
let now = SystemTime::now();
let sleep_for = sleep_until
.duration_since(now)
.unwrap_or(Duration::from_secs(0));
tokio::time::sleep(sleep_for).await
}
}
}
#[cfg(test)]
mod test_utils {
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use chrono::{DateTime, TimeZone, Utc};
use rand::{RngCore, SeedableRng};
use rand_xoshiro::Xoshiro256StarStar;
#[derive(Clone)]
pub struct RandGen(Arc<Mutex<Xoshiro256StarStar>>);
impl RandGen {
#[allow(clippy::new_without_default)]
pub fn new() -> RandGen {
RandGen(Arc::new(Mutex::new(Xoshiro256StarStar::seed_from_u64(
424242,
))))
}
}
impl RngCore for RandGen {
fn next_u32(&mut self) -> u32 {
let mut rng = self.0.lock().unwrap();
rng.next_u32()
}
fn next_u64(&mut self) -> u64 {
let mut rng = self.0.lock().unwrap();
rng.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
let mut rng = self.0.lock().unwrap();
rng.fill_bytes(dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
let mut rng = self.0.lock().unwrap();
rng.try_fill_bytes(dest)
}
}
#[derive(Clone)]
pub struct Clock(pub Arc<AtomicU64>);
impl Clock {
pub fn new_test() -> Self {
Clock(Arc::new(AtomicU64::new(0)))
}
pub fn set_time(&self, new: u64) {
self.0.store(new, std::sync::atomic::Ordering::Relaxed);
}
pub fn now_utc(&self) -> DateTime<Utc> {
Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0)
.earliest()
.unwrap()
}
pub fn now_timestamp(&self) -> i64 {
self.0.load(std::sync::atomic::Ordering::Relaxed) as i64
}
pub async fn sleep_until(&self, until_ts: i64) {
if until_ts < 0 {
return;
}
// Wait for time to advance past the requested timestamp
// TODO write a better test sleep implementation
while self.now_timestamp() < until_ts {
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
}
}

View File

@ -24,16 +24,16 @@ use axum::{
routing::{get, post},
Extension, Router,
};
use eyre::Context;
use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, RateLimiter};
use hornbeam::{initialise_template_manager, make_template_manager};
use tower_cookies::CookieManagerLayer;
use tower_http::{cors::CorsLayer, set_header::SetResponseHeaderLayer, trace::TraceLayer};
use tracing::{error, info};
use tracing::error;
use crate::{
config::{Configuration, RatelimiterConfig, RatelimitsConfig, SecretConfig},
store::IdCoopStore,
utils::{Clock, RandGen},
web::{
login::{get_login, post_login, PasswordHashInflightLimiter},
oauth_openid::{
@ -55,14 +55,15 @@ make_template_manager! {
};
}
/// Serves, on the bind address specified, the HTTP service
/// including a user interface and any OAuth, OpenID Connect and custom APIs.
pub async fn serve(
bind: SocketAddr,
/// Make an axum `Router` but do not bind it to a port.
/// This exposition allows us to perform integration testing easily.
pub(crate) async fn make_router(
store: Arc<IdCoopStore>,
config: Arc<Configuration>,
secrets: Arc<SecretConfig>,
) -> eyre::Result<()> {
clock: Clock,
randgen: RandGen,
) -> eyre::Result<Router> {
initialise_template_manager!(TEMPLATING);
let client_ip_source = config.listen.client_ip_source.clone();
@ -124,7 +125,26 @@ pub async fn serve(
.layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1))))
.layer(client_ip_source.into_extension())
.layer(Extension(Arc::new(ratelimiters)))
.layer(Extension(VolatileCodeStore::default()));
.layer(Extension(VolatileCodeStore::new(clock.clone())))
.layer(Extension(clock))
.layer(Extension(randgen));
Ok(router)
}
/// Serves, on the bind address specified, the HTTP service
/// including a user interface and any OAuth, OpenID Connect and custom APIs.
#[cfg(not(test))]
pub async fn serve(
bind: SocketAddr,
store: Arc<IdCoopStore>,
config: Arc<Configuration>,
secrets: Arc<SecretConfig>,
) -> eyre::Result<()> {
use eyre::Context;
use tracing::info;
let router = make_router(store, config, secrets, Clock, RandGen).await?;
info!("Listening on {bind:?}");
axum::Server::try_bind(&bind)
@ -213,7 +233,7 @@ impl Ratelimiters {
/// Do some housekeeping if it hasn't been done recently.
pub fn housekeeping(&self) {
let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) else {
return
return;
};
let now = (now.as_secs() >> 2) as u32;

View File

@ -9,8 +9,8 @@ use std::{
use async_trait::async_trait;
use axum::{
extract::{FromRequestParts, Query},
headers::Cookie,
http::{request::Parts, uri::PathAndQuery, HeaderValue, StatusCode},
headers::Cookie as CookieHeader,
http::{request::Parts, uri::PathAndQuery, StatusCode},
response::{Html, IntoResponse, Response},
Extension, Form, TypedHeader,
};
@ -21,17 +21,18 @@ use chrono::{DateTime, Duration, TimeZone, Utc};
use eyre::eyre;
use eyre::{bail, Context, ContextCompat};
use governor::Jitter;
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::Deserialize;
use sqlx::types::Uuid;
use tokio::sync::Semaphore;
use tower_cookies::Cookies;
use tower_cookies::{Cookie, Cookies};
use tracing::error;
use crate::{
config::{Configuration, PasswordHashingConfig},
passwords::{check_hash, create_password_hash},
store::IdCoopStore,
utils::RandGen,
};
use super::{sessionless_xsrf, Ratelimiters, WebResult};
@ -203,17 +204,15 @@ where
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Ok(cookies) = TypedHeader::<Cookie>::from_request_parts(parts, state).await else {
let Ok(cookies) = TypedHeader::<CookieHeader>::from_request_parts(parts, state).await
else {
return Err((StatusCode::UNAUTHORIZED, "No login session."));
};
let Some(cookie_val) = cookies.get("__Host-LoginSession").map(str::to_owned) else {
return Err((StatusCode::UNAUTHORIZED, "No login session."));
};
let Ok(login_session_token) = BASE64_URL_SAFE_NO_PAD.decode(&cookie_val) else {
return Err((
StatusCode::UNAUTHORIZED,
"Invalid login session token."
));
return Err((StatusCode::UNAUTHORIZED, "Invalid login session token."));
};
if login_session_token.len() != LOGIN_SESSION_TOKEN_BYTES {
return Err((StatusCode::UNAUTHORIZED, "Invalid login session token."));
@ -268,11 +267,12 @@ pub async fn get_login(
current_session: Option<LoginSession>,
Query(query): Query<LoginQuery>,
cookies: Cookies,
Extension(mut randgen): Extension<RandGen>,
) -> Response {
match current_session {
Some(_session) => make_post_login_redirect(query.then),
None => {
let xsrf_token = sessionless_xsrf::get_token(&cookies);
let xsrf_token = sessionless_xsrf::get_token(&cookies, &mut randgen);
Html(format!("<form method='POST'>UN<input type='text' name='username'> PW<input type='password' name='password'> <input type='hidden' name='xsrf' value='{}'><button type='submit'>click here to login</button> (temporary form)</form>", xsrf_token)).into_response()
}
}
@ -335,6 +335,7 @@ pub async fn post_login(
Extension(phil): Extension<Arc<PasswordHashInflightLimiter>>,
SecureClientIp(src_ip): SecureClientIp,
Extension(ratelimiters): Extension<Arc<Ratelimiters>>,
Extension(mut randgen): Extension<RandGen>,
Form(form): Form<PostLoginForm>,
) -> WebResult<Response> {
ratelimiters.housekeeping();
@ -344,7 +345,7 @@ pub async fn post_login(
.await;
if !sessionless_xsrf::check_token(&cookies, &form.xsrf) {
// Invalid XSRF token: try again
return Ok(get_login(None, Query(query), cookies).await);
return Ok(get_login(None, Query(query), cookies, Extension(randgen)).await);
}
// retrieve user details
@ -400,11 +401,11 @@ pub async fn post_login(
};
// Generate a login session token and store the hash in our database
let login_session_token = thread_rng().gen::<[u8; LOGIN_SESSION_TOKEN_BYTES]>();
let login_session_token = randgen.gen::<[u8; LOGIN_SESSION_TOKEN_BYTES]>();
let login_session_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(login_session_token);
let login_session_token_hash: [u8; LOGIN_SESSION_TOKEN_HASH_BYTES] =
Blake2s256::digest(login_session_token).into();
let xsrf_secret = thread_rng().gen::<[u8; LOGIN_SESSION_XSRF_SECRET_BYTES]>();
let xsrf_secret = randgen.gen::<[u8; LOGIN_SESSION_XSRF_SECRET_BYTES]>();
// store session in the database
store
@ -417,20 +418,16 @@ pub async fn post_login(
.await
.context("failed to store session in database")?;
let expiry_date = chrono::Utc::now() + chrono::Duration::days(500);
let expiry_date_rfc1123 = expiry_date.format("%a, %d %b %Y %H:%M:%S GMT");
Ok((
[(
"Set-Cookie",
HeaderValue::from_str(&format!(
"__Host-LoginSession={}; Path=/; HttpOnly; SameSite=Strict; Secure; Expires={}",
login_session_token_b64, expiry_date_rfc1123
))
.expect("no reason we should fail to make a cookie"),
)],
make_post_login_redirect(query.then),
)
.into_response())
cookies.add(
Cookie::build("__Host-LoginSession", login_session_token_b64.clone())
.path("/")
.http_only(true)
.secure(true)
.same_site(tower_cookies::cookie::SameSite::Strict)
.max_age(time::Duration::days(500))
.finish(),
);
Ok(make_post_login_redirect(query.then))
}
/// Make a redirect for once the user has logged in.

View File

@ -9,7 +9,6 @@ use axum::{
Extension, Form,
};
use chrono::Utc;
use eyre::{Context, ContextCompat};
use serde::{Deserialize, Serialize};
@ -17,6 +16,7 @@ use tracing::{error, warn};
use crate::{
config::{Configuration, OidcClientConfiguration},
utils::{Clock, RandGen},
web::{
login::LoginSession,
make_login_redirect,
@ -69,6 +69,8 @@ pub async fn oidc_authorisation(
login_session: Option<LoginSession>,
Extension(config): Extension<Arc<Configuration>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(clock): Extension<Clock>,
Extension(mut randgen): Extension<RandGen>,
OriginalUri(uri): OriginalUri,
) -> Response {
let Query(query) = match query {
@ -109,7 +111,7 @@ pub async fn oidc_authorisation(
// If the application requires consent, then we should ask for that.
if !client_config.skip_consent {
return show_consent_page(login_session, client_config, &config).await;
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
}
// No consent needed: process the authorisation.
@ -120,6 +122,8 @@ pub async fn oidc_authorisation(
client_id,
client_config,
&config,
&mut randgen,
&clock,
&code_store,
)
.await
@ -133,11 +137,14 @@ pub struct PostConsentForm {
}
/// `POST /oidc/auth`
#[allow(clippy::too_many_arguments)]
pub async fn post_oidc_authorisation_consent(
Query(query): Query<AuthorisationQuery>,
login_session: Option<LoginSession>,
Extension(config): Extension<Arc<Configuration>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(clock): Extension<Clock>,
Extension(mut randgen): Extension<RandGen>,
OriginalUri(uri): OriginalUri,
Form(form): Form<PostConsentForm>,
) -> Response {
@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent(
};
if login_session
.validate_xsrf_token(&form.xsrf, Utc::now())
.validate_xsrf_token(&form.xsrf, clock.now_utc())
.is_err()
{
// XSRF token is not valid, so show the consent form again...
return show_consent_page(login_session, client_config, &config).await;
return show_consent_page(login_session, client_config, Extension(clock), &config).await;
}
match form.action.as_str() {
@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent(
client_id,
client_config,
&config,
&mut randgen,
&clock,
&code_store,
)
.await
@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>(
async fn show_consent_page(
login_session: LoginSession,
client_config: &OidcClientConfiguration,
Extension(clock): Extension<Clock>,
_config: &Configuration,
) -> Response {
let xsrf_token = login_session
.generate_xsrf_token(Utc::now())
.generate_xsrf_token(clock.now_utc())
.expect("must be able to create a XSRF token");
Html(format!(
"hi <u>{}</u>, consent to <u>{}</u>? <form method='POST'><input type='hidden' name='xsrf' value='{}'><button type='submit' name='action' value='accept'>Accept</button> <button type='submit' name='action' value='deny'>Deny</button></form>",
@ -253,12 +263,15 @@ async fn show_consent_page(
/// Preconditions:
/// - any required consent from the user has now been obtained
/// - query.request_uri has been validated as a safe redirect URI
#[allow(clippy::too_many_arguments)]
async fn process_authorisation(
query: AuthorisationQuery,
login_session: LoginSession,
client_id: String,
_client_config: &OidcClientConfiguration,
config: &Configuration,
randgen: &mut RandGen,
clock: &Clock,
code_store: &VolatileCodeStore,
) -> Response {
assert_eq!(
@ -271,7 +284,7 @@ async fn process_authorisation(
// Generate a 192-bit random code, which fits into exactly 32 base64 characters.
// This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random.
let code = AuthCode::generate_new_random();
let code = AuthCode::generate_new_random(randgen);
let code_base64url = code.to_string();
// Write down the code and other details in-memory with 10 minute expiry...
@ -289,7 +302,7 @@ async fn process_authorisation(
user_id: login_session.user_id,
user_login_session_id: login_session.login_session_id,
},
0,
clock.now_timestamp() + 600,
);
#[derive(Serialize)]

View File

@ -2,10 +2,10 @@
use std::{
collections::{BTreeSet, HashMap},
fmt::Debug,
fmt::Display,
str::FromStr,
sync::{Arc, Mutex},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine};
@ -13,10 +13,12 @@ use rand::Rng;
use sqlx::types::Uuid;
use tokio::sync::Notify;
use crate::utils::{Clock, RandGen};
/// Display shows the auth code as base64 (URL-safe non-padded).
/// FromStr/parse parses the same format.
#[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct AuthCode([u8; 24]);
pub struct AuthCode(pub [u8; 24]);
/// Access token
pub type AccessToken = [u8; 32];
@ -37,6 +39,12 @@ impl Display for AuthCode {
}
}
impl Debug for AuthCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}
impl FromStr for AuthCode {
type Err = &'static str;
@ -54,14 +62,15 @@ impl FromStr for AuthCode {
impl AuthCode {
/// Generate a new authorisation code using the thread's RNG
pub fn generate_new_random() -> Self {
Self(rand::thread_rng().gen::<[u8; 24]>())
pub fn generate_new_random(randgen: &mut RandGen) -> Self {
Self(randgen.gen::<[u8; 24]>())
}
}
/// Binding between an authorisation code (ready to be redeemed)
/// and both the user that authenticated to produce it
/// as well as the OpenID Connect client that it is for.
#[derive(PartialEq, Eq, Debug)]
pub struct AuthCodeBinding {
/// ID of the OpenID Connect client
pub client_id: String,
@ -104,7 +113,7 @@ struct VolatileCodeStoreInner {
pub conflictable_codes: HashMap<AuthCode, RedeemedAuthCode>,
/// Time when codes will expire
pub expire_codes_at: BTreeSet<(u64, AuthCode)>,
pub expire_codes_at: BTreeSet<(i64, AuthCode)>,
}
impl VolatileCodeStoreInner {
@ -140,7 +149,7 @@ impl VolatileCodeStoreInner {
&mut self,
auth_code: AuthCode,
auth_code_binding: AuthCodeBinding,
expires_at: u64,
expires_at: i64,
) {
self.redeemable_codes
.insert(auth_code.clone(), auth_code_binding);
@ -148,13 +157,15 @@ impl VolatileCodeStoreInner {
}
/// Removes all expired auth codes and returns the time of the earliest next expiry, if present.
pub(self) fn handle_expiry(&mut self, now: u64) -> Option<u64> {
pub(self) fn handle_expiry(&mut self, now: i64) -> Option<i64> {
loop {
let (ts, _auth_code) = self.expire_codes_at.first()?;
// Remove if expired
if *ts <= now {
self.expire_codes_at.pop_first();
let (_, auth_code) = self.expire_codes_at.pop_first().unwrap();
self.redeemable_codes.remove(&auth_code);
self.conflictable_codes.remove(&auth_code);
continue;
}
@ -172,15 +183,16 @@ pub struct VolatileCodeStore {
inner: Arc<Mutex<VolatileCodeStoreInner>>,
}
impl Default for VolatileCodeStore {
fn default() -> Self {
impl VolatileCodeStore {
/// Create a new instance.
pub fn new(clock: Clock) -> Self {
let poke = Arc::new(Notify::new());
let inner: Arc<Mutex<VolatileCodeStoreInner>> = Default::default();
{
let poke = poke.clone();
let inner = inner.clone();
tokio::spawn(Self::expirer(inner, poke));
tokio::spawn(Self::expirer(inner, poke, clock));
}
VolatileCodeStore { inner, poke }
@ -188,19 +200,15 @@ impl Default for VolatileCodeStore {
}
impl VolatileCodeStore {
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>) {
let mut next_expiry: Option<u64> = None;
async fn expirer(inner: Arc<Mutex<VolatileCodeStoreInner>>, poke: Arc<Notify>, clock: Clock) {
let mut next_expiry: Option<i64> = None;
loop {
match next_expiry {
Some(next_expiry) => {
let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry);
let now = SystemTime::now();
let sleep_for = sleep_until
.duration_since(now)
.unwrap_or(Duration::from_secs(60));
let sleep_future = clock.sleep_until(next_expiry);
tokio::select! {
_ = poke.notified() => {},
_ = tokio::time::sleep(sleep_for) => {},
_ = sleep_future => {},
}
}
None => {
@ -208,14 +216,9 @@ impl VolatileCodeStore {
}
}
let now = SystemTime::now();
next_expiry = {
let mut inner = inner.lock().unwrap();
inner.handle_expiry(
now.duration_since(UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_secs(),
)
inner.handle_expiry(clock.now_timestamp())
};
}
}
@ -234,7 +237,7 @@ impl VolatileCodeStore {
}
/// Add a new redeemable authorisation code.
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) {
pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) {
let mut inner = self.inner.lock().unwrap();
inner.add_redeemable(auth_code, binding, expires_at);
drop(inner);
@ -244,6 +247,7 @@ impl VolatileCodeStore {
}
/// Possible outcomes of an attempt to redeem an authorisation code.
#[derive(PartialEq, Eq, Debug)]
pub enum CodeRedemption {
/// That auth code was not active
Invalid,
@ -264,3 +268,212 @@ pub enum CodeRedemption {
refresh_token_to_invalidate: RefreshTokenHash,
},
}
#[cfg(test)]
mod test {
use std::{str::FromStr, time::Duration};
use assert_matches2::assert_matches;
use rstest::{fixture, rstest};
use sqlx::types::Uuid;
use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption};
use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner};
const VALID_CODE: AuthCode = AuthCode([21; 24]);
const USER_UUID: Uuid = Uuid::nil();
const USER_LOGIN_SESSION_ID: i32 = 1347;
#[fixture]
fn code_store() -> VolatileCodeStoreInner {
let mut vcs = VolatileCodeStoreInner::default();
vcs.add_redeemable(
VALID_CODE.clone(),
AuthCodeBinding {
client_id: "client_id".to_owned(),
redirect_uri: "https://client/redirect".to_owned(),
nonce: None,
code_challenge_method: "method".to_owned(),
code_challenge: "challenge".to_owned(),
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
128,
);
vcs
}
#[rstest]
fn test_redeem_nonexistent_code(mut code_store: VolatileCodeStoreInner) {
let ac = AuthCode([0; 24]);
let result = code_store.redeem(&ac, [42; 32], [43; 32]);
assert_eq!(result, CodeRedemption::Invalid);
}
#[rstest]
fn test_redeem_real_code(mut code_store: VolatileCodeStoreInner) {
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_matches!(result, CodeRedemption::Valid { binding });
assert_eq!(binding.client_id, "client_id");
}
#[rstest]
fn test_redeem_code_twice(mut code_store: VolatileCodeStoreInner) {
// redeem a first time
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_matches!(result, CodeRedemption::Valid { .. });
// redeem a second time
let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]);
assert_matches!(
result,
CodeRedemption::Conflicted {
access_token_to_invalidate,
refresh_token_to_invalidate
}
);
assert_eq!(access_token_to_invalidate, [42; 32]);
assert_eq!(refresh_token_to_invalidate, [43; 32]);
}
#[rstest]
fn test_expire_and_redeem_not_expired_yet(mut code_store: VolatileCodeStoreInner) {
code_store.handle_expiry(127);
// The code shouldn't expire yet.
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_matches!(result, CodeRedemption::Valid { .. });
}
#[rstest]
fn test_expire_and_redeem_expired_now(mut code_store: VolatileCodeStoreInner) {
code_store.handle_expiry(128);
// The code should have expired
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_eq!(result, CodeRedemption::Invalid);
}
#[rstest]
fn test_expire_and_redeem_conflict(mut code_store: VolatileCodeStoreInner) {
// redeem a first time
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_matches!(result, CodeRedemption::Valid { .. });
code_store.handle_expiry(127);
// redeem a second time: the conflict should not have expired yet.
let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]);
assert_matches!(result, CodeRedemption::Conflicted { .. });
// redeem a third time to show the conflict doesn't get removed.
let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]);
assert_matches!(result, CodeRedemption::Conflicted { .. });
code_store.handle_expiry(128);
// The code and its conflict entry should have expired
let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]);
assert_eq!(result, CodeRedemption::Invalid);
}
#[test]
fn test_authcode_display() {
assert_eq!(VALID_CODE.to_string(), "FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV");
assert_eq!(
format!("{VALID_CODE:?}"),
"FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV"
);
}
#[test]
fn test_authcode_from_string() {
assert_eq!(
AuthCode::from_str(&VALID_CODE.to_string()).unwrap(),
VALID_CODE
);
}
#[test]
fn test_authcode_from_string_wrong_size() {
// Not a valid base64 string
assert_eq!(AuthCode::from_str("a").unwrap_err(), "wrong size");
// Not 24 bytes when decoded
assert_eq!(AuthCode::from_str("abcd").unwrap_err(), "wrong size");
}
/// Basic test for the full-fat VolatileCodeStore, not just its inner logic
/// We can't easily cover everything here but may as well test the basics.
#[tokio::test]
async fn test_fullfat_store_basic() {
let clock = Clock::new_test();
let vcs = VolatileCodeStore::new(clock.clone());
vcs.add_redeemable(
VALID_CODE.clone(),
AuthCodeBinding {
client_id: "client_id".to_owned(),
redirect_uri: "https://client/redirect".to_owned(),
nonce: None,
code_challenge_method: "method".to_owned(),
code_challenge: "challenge".to_owned(),
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
i64::MAX,
);
assert_matches!(
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),
CodeRedemption::Valid { .. }
);
vcs.add_redeemable(
VALID_CODE.clone(),
AuthCodeBinding {
client_id: "client_id".to_owned(),
redirect_uri: "https://client/redirect".to_owned(),
nonce: None,
code_challenge_method: "method".to_owned(),
code_challenge: "challenge".to_owned(),
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
// Expiry in the past: this should be auto-expired when poked and
// given a moment
1,
);
clock.set_time(2);
vcs.add_redeemable(
AuthCode([0; 24]),
AuthCodeBinding {
client_id: "client_id".to_owned(),
redirect_uri: "https://client/redirect".to_owned(),
nonce: None,
code_challenge_method: "method".to_owned(),
code_challenge: "challenge".to_owned(),
user_id: USER_UUID,
user_login_session_id: USER_LOGIN_SESSION_ID,
},
i64::MAX,
);
// Give a short time for the expiry to take place.
tokio::time::sleep(Duration::from_millis(2)).await;
assert_matches!(
vcs.redeem(&VALID_CODE, [1; 32], [2; 32]),
CodeRedemption::Invalid
);
}
}

View File

@ -1,10 +1,6 @@
//! `/oidc/token`
use std::{
str::FromStr,
sync::Arc,
time::{SystemTime, UNIX_EPOCH},
};
use std::{str::FromStr, sync::Arc};
use axum::{
extract::rejection::FormRejection,
@ -15,13 +11,13 @@ use axum::{
};
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use blake2::Blake2s256;
use chrono::{Duration, Utc};
use chrono::Duration;
use eyre::{bail, Context};
use josekit::{
jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader},
jwt::JwtPayload,
};
use rand::{thread_rng, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
@ -30,6 +26,7 @@ use tracing::{debug, error};
use crate::{
config::{Configuration, SecretConfig},
store::IdCoopStore,
utils::{Clock, RandGen},
};
use super::ext_codes::{
@ -53,12 +50,15 @@ pub struct TokenFormParams {
/// OpenID Connect clients call this to exchange an authorisation code they received for an access token.
///
/// TODO auth_header can be one alternative auth method
#[allow(clippy::too_many_arguments)]
pub async fn oidc_token(
basic_auth: Option<TypedHeader<Authorization<Basic>>>,
Extension(config): Extension<Arc<Configuration>>,
Extension(secrets): Extension<Arc<SecretConfig>>,
Extension(store): Extension<Arc<IdCoopStore>>,
Extension(code_store): Extension<VolatileCodeStore>,
Extension(mut randgen): Extension<RandGen>,
Extension(clock): Extension<Clock>,
form: Result<Form<TokenFormParams>, FormRejection>,
) -> impl IntoResponse {
let form = match form {
@ -110,8 +110,9 @@ pub async fn oidc_token(
Json(TokenError {
code: TokenErrorCode::InvalidClient,
description: "That `client_id` is not recognised here.".to_string(),
})
).into_response();
}),
)
.into_response();
};
if !bool::from(
@ -181,10 +182,10 @@ pub async fn oidc_token(
// Create an access token but don't actually issue it yet:
// This lets us store the hash of the access token against the redemption of the auth code,
// so double redemptions can invalidate the access token appropriately.
let access_token = thread_rng().gen::<AccessToken>();
let access_token = randgen.gen::<AccessToken>();
let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token);
let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into();
let refresh_token = thread_rng().gen::<RefreshToken>();
let refresh_token = randgen.gen::<RefreshToken>();
let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token);
let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into();
@ -256,7 +257,9 @@ pub async fn oidc_token(
}
// 2. Check the code challenge
let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else {
let Some(computed_code_challenge) =
compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier)
else {
return (
StatusCode::BAD_REQUEST,
Json(TokenError {
@ -287,28 +290,29 @@ pub async fn oidc_token(
let user_login_session_id = binding.user_login_session_id;
// Issue access token for a new session
let clock2 = clock.clone();
match store
.txn(move |mut txn| {
Box::pin(async move {
let Some(session_id) = txn
.create_application_session(user_id, &application_id, user_login_session_id)
.await
.context("create_application_session")? else {
return Ok(Err(
(
StatusCode::BAD_REQUEST,
Json(TokenError {
code: TokenErrorCode::InvalidGrant,
description: "Auth code has expired or was not valid.".to_owned(),
}),
).into_response()
));
.context("create_application_session")?
else {
return Ok(Err((
StatusCode::BAD_REQUEST,
Json(TokenError {
code: TokenErrorCode::InvalidGrant,
description: "Auth code has expired or was not valid.".to_owned(),
}),
)
.into_response()));
};
txn.issue_access_token(
&access_token_hash,
session_id,
// TODO(expiry) Support custom expiry, not 100 years
Utc::now() + Duration::days(365 * 100),
clock2.now_utc() + Duration::days(365 * 100),
)
.await
.context("issue_access_token")?;
@ -316,7 +320,7 @@ pub async fn oidc_token(
&refresh_token_hash,
session_id,
// TODO(expiry) Support custom expiry, not 100 years
Utc::now() + Duration::days(365 * 100),
clock2.now_utc() + Duration::days(365 * 100),
)
.await
.context("issue_refresh_token")?;
@ -344,10 +348,7 @@ pub async fn oidc_token(
}
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Before unix epoch?")
.as_secs();
let now = clock.now_timestamp();
// TODO(expiry) Support custom expiry times (not just 100 years)
let exp = now + 100 * 365 * 86400;
let sub = binding.user_id.hyphenated().to_string();
@ -385,7 +386,9 @@ pub async fn oidc_token(
}
fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result<String> {
let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else {
let Ok(serde_json::Value::Object(map)) =
serde_json::to_value(id_token).context("failed to serialise ID Token content")
else {
bail!("ID Token not a map");
};
@ -447,9 +450,9 @@ pub struct IdToken {
/// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew.
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
/// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular.
pub exp: u64,
pub exp: i64,
/// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
pub iat: u64,
pub iat: i64,
/// Time when the End-User authentication occurred.
/// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time.
/// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL.

View File

@ -7,20 +7,22 @@
//! Even on older browsers, that type of attack is rare, so this 'naïve' scheme is fine for the purpose (rather than having to do anything complicated).
use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine};
use rand::{thread_rng, Rng};
use rand::Rng;
use subtle::ConstantTimeEq;
use time::Duration;
use tower_cookies::{Cookie, Cookies};
use crate::utils::RandGen;
/// Name of the cookie for the sessionless cross-site request forgery prevention cookie.
pub const COOKIE_NAME: &str = "__Host-SessionlessXsrf";
/// Gets the Sessionless XSRF token to put into a form request
pub fn get_token(cookies: &Cookies) -> String {
pub fn get_token(cookies: &Cookies, randgen: &mut RandGen) -> String {
if let Some(xsrf_cookie) = cookies.get(COOKIE_NAME) {
xsrf_cookie.value().to_owned()
} else {
let new_token = thread_rng().gen::<[u8; 16]>();
let new_token = randgen.gen::<[u8; 16]>();
let new_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(new_token);
cookies.add(
Cookie::build(COOKIE_NAME, new_token_b64.clone())
@ -37,7 +39,9 @@ pub fn get_token(cookies: &Cookies) -> String {
/// Checks a Sessionless XSRF token obtained from a form request
pub fn check_token(cookies: &Cookies, token: &str) -> bool {
let Some(xsrf_token) = cookies.get(COOKIE_NAME) else { return false };
let Some(xsrf_token) = cookies.get(COOKIE_NAME) else {
return false;
};
bool::from(xsrf_token.value().as_bytes().ct_eq(token.as_bytes()))
}