diff --git a/Cargo.lock b/Cargo.lock index 32feffd..12095d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -152,6 +152,12 @@ dependencies = [ "password-hash", ] +[[package]] +name = "assert_matches2" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15832d94c458da98cac0ffa6eca52cc19c2a3c6c951058500a5ae8f01f0fdf56" + [[package]] name = "async-recursion" version = "1.0.5" @@ -265,6 +271,24 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "axum-test-helper" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298f62fa902c2515c169ab0bfb56c593229f33faa01131215d58e3d4898e3aa9" +dependencies = [ + "axum", + "bytes", + "http", + "http-body", + "hyper", + "reqwest", + "serde", + "tokio", + "tower", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -301,7 +325,7 @@ dependencies = [ "quote", "rustc-hash", "syn 2.0.38", - "toml_edit", + "toml_edit 0.19.15", ] [[package]] @@ -595,6 +619,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -854,6 +890,21 @@ dependencies = [ "serde", ] +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -1093,9 +1144,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -1291,6 +1342,25 @@ dependencies = [ "smallvec", ] +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap 2.0.2", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1527,6 +1597,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "httparse", @@ -1568,9 +1639,11 @@ name = "idcoop" version = "0.0.1" dependencies = [ "argon2", + "assert_matches2", "async-trait", "axum", "axum-client-ip", + "axum-test-helper", "base64", "blake2", "chrono", @@ -1581,11 +1654,16 @@ dependencies = [ "futures", "governor", "hornbeam", + "insta", "josekit", + "maplit", "metrics", "metrics-exporter-prometheus", "metrics-process", + "pgtemp", "rand", + "rand_xoshiro", + "rstest", "serde", "serde_json", "serde_urlencoded", @@ -1602,9 +1680,9 @@ dependencies = [ [[package]] name = "idna" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -1673,6 +1751,19 @@ dependencies = [ "libc", ] +[[package]] +name = "insta" +version = "1.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "810ae6042d48e2c9e9215043563a58a80b877bc863228a74cf10c49d4620a6f5" +dependencies = [ + "console", + "lazy_static", + "linked-hash-map", + "serde", + "similar", +] + [[package]] name = "instant" version = "0.1.12" @@ -1851,6 +1942,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.1.4" @@ -1888,6 +1985,12 @@ dependencies = [ "libc", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "matchers" version = "0.1.0" @@ -2004,6 +2107,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2285,9 +2398,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" @@ -2366,6 +2479,18 @@ dependencies = [ "indexmap 2.0.2", ] +[[package]] +name = "pgtemp" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc39977b03503cfcf74ce2f8938d7b8340d1f8ef31a3f2a289b1c1abff8ace2c" +dependencies = [ + "libc", + "tempfile", + "tokio", + "url", +] + [[package]] name = "pin-project" version = "1.1.3" @@ -2449,6 +2574,15 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-crate" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284" +dependencies = [ + "toml_edit 0.21.1", +] + [[package]] name = "proc-macro-hack" version = "0.5.20+deprecated" @@ -2532,6 +2666,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "raw-cpuid" version = "10.7.0" @@ -2603,6 +2746,51 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3cbb081b9784b07cceb8824c8583f86db4814d172ab043f3c23f7dc600bf83d" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "ipnet", + "js-sys", + "log", + "mime", + "mime_guess", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "winreg", +] + [[package]] name = "rlimit" version = "0.10.1" @@ -2632,6 +2820,36 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rstest" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afd55a67069d6e434a95161415f5beeada95a01c7b815508a82dcb0e1593682" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4165dfae59a39dd41d8dec720d3cbfbc71f69744efb480a3920f5d4e0cc6798d" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.38", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2644,6 +2862,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.36.15" @@ -2736,6 +2963,12 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ef965a420fe14fdac7dd018862966a4c14094f900e1650bbc71ddd7d580c8af" +[[package]] +name = "semver" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" + [[package]] name = "serde" version = "1.0.188" @@ -2827,6 +3060,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7cee0529a6d40f580e7a5e6c495c8fbfe21b7b52795ed4bb5e62cdf92bc6380" +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.1.0" @@ -2837,6 +3079,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "similar" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa42c91313f1d05da9b26f267f931cf178d4aba455b4c4622dd7355eb80c6640" + [[package]] name = "sketches-ddsketch" version = "0.2.1" @@ -3210,6 +3458,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.8.1" @@ -3325,7 +3594,10 @@ dependencies = [ "bytes", "libc", "mio", + "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2 0.5.4", "tokio-macros", "windows-sys 0.48.0", @@ -3353,6 +3625,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.5.11" @@ -3379,6 +3664,17 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_edit" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" +dependencies = [ + "indexmap 2.0.2", + "toml_datetime", + "winnow", +] + [[package]] name = "tower" version = "0.4.13" @@ -3576,6 +3872,15 @@ dependencies = [ "unic-langid-impl", ] +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -3617,9 +3922,9 @@ checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" [[package]] name = "url" -version = "2.4.1" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", @@ -3716,6 +4021,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.87" @@ -3745,6 +4062,19 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.64" @@ -4025,6 +4355,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "zeroize" version = "1.6.0" diff --git a/Cargo.toml b/Cargo.toml index 67f2111..c124ee9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ metrics = "0.21.1" metrics-exporter-prometheus = "0.12.1" metrics-process = "1.0.12" rand = "0.8.5" +rand_xoshiro = "0.6.0" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1.0.108" serde_urlencoded = "0.7.1" @@ -39,3 +40,12 @@ tower-cookies = "0.9.0" tower-http = { version = "0.4.4", features = ["trace", "cors", "set-header"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } + +[dev-dependencies] +assert_matches2 = "0.1.2" +axum-test-helper = "0.3.0" +insta = { version = "1.39.0", features = ["serde", "yaml"] } +maplit = "1.0.2" +pgtemp = "0.3.0" +rand_xoshiro = "0.6.0" +rstest = "0.21.0" diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 23f6caf..671305a 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -10,4 +10,4 @@ # Development - +- [Testing](dev/testing.md) diff --git a/docs/dev/testing.md b/docs/dev/testing.md new file mode 100644 index 0000000..e6aa61a --- /dev/null +++ b/docs/dev/testing.md @@ -0,0 +1,45 @@ +# Testing + +## Testing approach + +### Unit tests + +Unit tests should live inside a module in the code unit they are testing, +gated behind `#[cfg(test)]`. \ +This is fairly common Rust practice. + +There is no hard and fast rule for the granularity of the unit tests, +but they should test the smallest amount of logic that is simple to test, +but no smaller. \ +In practice this means that unit tests should be at the function-level +where this makes sense, or at the struct-level if this makes more sense. + +For now, avoid the use of test mocks, but use them if it makes +strong sense to do so. + + +### Integration tests + +Integration tests should live in the `tests/` directory. + +In general, each test will get its own throwaway Postgres database. + + +#### Snapshot tests + +Some of the integration tests will compare snapshots (of e.g. HTML) against +a gold standard. + +When a new snapshot is created, the output should be manually verified, +including in a browser if necessary. + +It goes without saying that all snapshot changes should be expected; +if they are not then treated as failures. + + +### End-to-end tests + +idCoop doesn't currently have end-to-end tests but this is on the wishlist +for the future. + +Will eventually look into Playwright etc. diff --git a/flake-devenv/flake.nix b/flake-devenv/flake.nix index eb918ed..e2b4f9d 100644 --- a/flake-devenv/flake.nix +++ b/flake-devenv/flake.nix @@ -74,6 +74,9 @@ # Test coverage. Vaguely useful but not definitive. pkgs.cargo-tarpaulin + # Snapshot testing + pkgs.cargo-insta + pkgs.grass-sass pkgs.entr diff --git a/src/bin/idcoop.rs b/src/bin/idcoop.rs index cbdfc4a..d02a75d 100644 --- a/src/bin/idcoop.rs +++ b/src/bin/idcoop.rs @@ -1,15 +1,12 @@ -use std::io::stdin; use std::sync::Arc; use std::{net::SocketAddr, path::PathBuf}; use clap::Parser; -use comfy_table::presets::UTF8_FULL; -use comfy_table::{Attribute, Cell, Color, ContentArrangement, Row, Table}; use confique::{Config, Partial}; -use eyre::{bail, Context, ContextCompat}; +use eyre::{bail, Context}; +use idcoop::cli::{handle_user_command, UserCommand}; use idcoop::config::{SecretConfig, SeparateSecretConfiguration}; -use idcoop::passwords::create_password_hash; -use idcoop::store::{CreateUser, IdCoopStore}; +use idcoop::store::IdCoopStore; use idcoop::{config::Configuration, web}; use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::layer::SubscriberExt; @@ -106,180 +103,15 @@ async fn main() -> eyre::Result<()> { .await .context("Failed to connect to Postgres")?; let secrets = SecretConfig::try_new(&config).await?; - web::serve(bind, Arc::new(store), Arc::new(config), Arc::new(secrets)).await? + web::serve(bind, Arc::new(store), Arc::new(config), Arc::new(secrets)).await?; } - Subcommand::User { cmd } => handle_user_command(cmd, &config).await?, - } - - Ok(()) -} - -/// Commands for user management. -#[derive(Clone, Parser)] -enum UserCommand { - /// Add a user. - #[clap(alias = "new", alias = "create")] - Add { - /// The login name of the user. - // TODO this should be a richer newtype with validation - username: String, - - #[clap(long = "locked")] - locked: bool, - }, - /// Deletes a user. - /// Consider whether this is what you really want: in most cases locking a user is more appropriate. - #[clap(alias = "remove", alias = "rm", alias = "del")] - Delete { - /// The login name of the user. - username: String, - }, - /// Locks a user, preventing them from logging in. - Lock { - /// The login name of the user. - username: String, - }, - /// Unlocks a user, letting them log in once more. - Unlock { - /// The login name of the user. - username: String, - }, - /// Changes a user's password. - #[clap(alias = "chpass", alias = "passwd")] - ChangePassword { - /// The login name of the user. - username: String, - }, - /// Lists all users that are registered. - #[clap(alias = "ls")] - ListAll { - /// Only show a list of usernames, without table formatting characters and one per line. May be useful in scripts. - #[clap(long = "usernames")] - usernames: bool, - }, -} - -async fn handle_user_command(command: UserCommand, config: &Configuration) -> eyre::Result<()> { - let store = IdCoopStore::connect(&config.postgres.connect) - .await - .context("Failed to connect to Postgres")?; - match command { - UserCommand::Add { username, locked } => { - store - .txn(|mut txn| { - Box::pin(async move { - txn.create_user(CreateUser { - user_login_name: username, - password_hash: None, - locked, - }) - .await - }) - }) + Subcommand::User { cmd } => { + let store = IdCoopStore::connect(&config.postgres.connect) .await - .context("failed to add user")?; - } - UserCommand::Delete { username } => { - store - .txn(|mut txn| { - Box::pin(async move { - let user_id = txn - .lookup_user_by_name(username) - .await? - .context("No user by that name")? - .user_id; - txn.delete_user(user_id).await - }) - }) - .await?; - } - UserCommand::Lock { username } => { - store - .txn(|mut txn| { - Box::pin(async move { - let user_id = txn - .lookup_user_by_name(username) - .await? - .context("No user by that name")? - .user_id; - txn.set_user_locked(user_id, true).await - }) - }) - .await?; - } - UserCommand::Unlock { username } => { - store - .txn(|mut txn| { - Box::pin(async move { - let user_id = txn - .lookup_user_by_name(username) - .await? - .context("No user by that name")? - .user_id; - txn.set_user_locked(user_id, false).await - }) - }) - .await?; - } - UserCommand::ChangePassword { username } => { - let Some(user) = store.txn(|mut txn| { Box::pin(async move { - txn.lookup_user_by_name(username).await - })}).await? else { - bail!("No user by that name."); - }; - println!("Change password for {} ({}):", user.user_name, user.user_id); - let mut buf_line = String::new(); - stdin() - .read_line(&mut buf_line) - .context("failed to read password")?; - let raw_password = buf_line.trim(); - let hash = create_password_hash(raw_password, &config.password_hashing) - .context("unable to hash password!")?; - store - .txn(|mut txn| { - Box::pin( - async move { txn.change_user_password(user.user_id, Some(hash)).await }, - ) - }) - .await?; - } - UserCommand::ListAll { usernames } => { - let user_infos = store - .txn(|mut txn| Box::pin(async move { txn.list_user_info().await })) - .await?; - - if usernames { - for user_info in user_infos { - println!("{}", user_info.user_name); - } - } else { - let mut table = Table::new(); - table - .load_preset(UTF8_FULL) - .set_content_arrangement(ContentArrangement::Dynamic) - .set_width(80) - .set_header(vec![ - Cell::new("Name").add_attribute(Attribute::Bold), - Cell::new("UUID").add_attribute(Attribute::Bold), - Cell::new("Locked").add_attribute(Attribute::Bold), - ]); - - for user_info in user_infos { - let mut row = Row::new(); - row.add_cell(Cell::new(user_info.user_name)); - row.add_cell(Cell::new(user_info.user_id).fg(Color::Grey)); - let (lock_str, lock_colour) = if user_info.locked { - ("yes", Color::Red) - } else { - ("no", Color::White) - }; - row.add_cell(Cell::new(lock_str).fg(lock_colour)); - table.add_row(row); - } - - println!("{}", table); - } + .context("Failed to connect to Postgres")?; + handle_user_command(cmd, &config, &store).await?; } } + Ok(()) } diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..9d460c3 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,185 @@ +//! idCoop Command Line Interface + +use std::io::stdin; + +use crate::config::Configuration; +use crate::passwords::create_password_hash; +use crate::store::{CreateUser, IdCoopStore}; +use clap::Parser; +use comfy_table::presets::UTF8_FULL; +use comfy_table::{Attribute, Cell, Color, ContentArrangement, Row, Table}; +use eyre::{bail, Context, ContextCompat}; + +/// Commands for user management. +#[derive(Clone, Parser)] +pub enum UserCommand { + /// Add a user. + #[clap(alias = "new", alias = "create")] + Add { + /// The login name of the user. + // TODO this should be a richer newtype with validation + username: String, + + /// Set this flag if the user should be locked. + #[clap(long = "locked")] + locked: bool, + }, + /// Deletes a user. + /// Consider whether this is what you really want: in most cases locking a user is more appropriate. + #[clap(alias = "remove", alias = "rm", alias = "del")] + Delete { + /// The login name of the user. + username: String, + }, + /// Locks a user, preventing them from logging in. + Lock { + /// The login name of the user. + username: String, + }, + /// Unlocks a user, letting them log in once more. + Unlock { + /// The login name of the user. + username: String, + }, + /// Changes a user's password. + #[clap(alias = "chpass", alias = "passwd")] + ChangePassword { + /// The login name of the user. + username: String, + }, + /// Lists all users that are registered. + #[clap(alias = "ls")] + ListAll { + /// Only show a list of usernames, without table formatting characters and one per line. May be useful in scripts. + #[clap(long = "usernames")] + usernames: bool, + }, +} + +/// Handles a user command from the command-line interface. +pub async fn handle_user_command( + command: UserCommand, + config: &Configuration, + store: &IdCoopStore, +) -> eyre::Result<()> { + match command { + UserCommand::Add { username, locked } => { + store + .txn(|mut txn| { + Box::pin(async move { + txn.create_user(CreateUser { + user_login_name: username, + password_hash: None, + locked, + }) + .await + }) + }) + .await + .context("failed to add user")?; + } + UserCommand::Delete { username } => { + store + .txn(|mut txn| { + Box::pin(async move { + let user_id = txn + .lookup_user_by_name(username) + .await? + .context("No user by that name")? + .user_id; + txn.delete_user(user_id).await + }) + }) + .await?; + } + UserCommand::Lock { username } => { + store + .txn(|mut txn| { + Box::pin(async move { + let user_id = txn + .lookup_user_by_name(username) + .await? + .context("No user by that name")? + .user_id; + txn.set_user_locked(user_id, true).await + }) + }) + .await?; + } + UserCommand::Unlock { username } => { + store + .txn(|mut txn| { + Box::pin(async move { + let user_id = txn + .lookup_user_by_name(username) + .await? + .context("No user by that name")? + .user_id; + txn.set_user_locked(user_id, false).await + }) + }) + .await?; + } + UserCommand::ChangePassword { username } => { + let Some(user) = store + .txn(|mut txn| Box::pin(async move { txn.lookup_user_by_name(username).await })) + .await? + else { + bail!("No user by that name."); + }; + println!("Change password for {} ({}):", user.user_name, user.user_id); + let mut buf_line = String::new(); + stdin() + .read_line(&mut buf_line) + .context("failed to read password")?; + let raw_password = buf_line.trim(); + let hash = create_password_hash(raw_password, &config.password_hashing) + .context("unable to hash password!")?; + store + .txn(|mut txn| { + Box::pin( + async move { txn.change_user_password(user.user_id, Some(hash)).await }, + ) + }) + .await?; + } + UserCommand::ListAll { usernames } => { + let user_infos = store + .txn(|mut txn| Box::pin(async move { txn.list_user_info().await })) + .await?; + + if usernames { + for user_info in user_infos { + println!("{}", user_info.user_name); + } + } else { + let mut table = Table::new(); + table + .load_preset(UTF8_FULL) + .set_content_arrangement(ContentArrangement::Dynamic) + .set_width(80) + .set_header(vec![ + Cell::new("Name").add_attribute(Attribute::Bold), + Cell::new("UUID").add_attribute(Attribute::Bold), + Cell::new("Locked").add_attribute(Attribute::Bold), + ]); + + for user_info in user_infos { + let mut row = Row::new(); + row.add_cell(Cell::new(user_info.user_name)); + row.add_cell(Cell::new(user_info.user_id).fg(Color::Grey)); + let (lock_str, lock_colour) = if user_info.locked { + ("yes", Color::Red) + } else { + ("no", Color::White) + }; + row.add_cell(Cell::new(lock_str).fg(lock_colour)); + table.add_row(row); + } + + println!("{}", table); + } + } + } + Ok(()) +} diff --git a/src/config.rs b/src/config.rs index d237d70..6d36929 100644 --- a/src/config.rs +++ b/src/config.rs @@ -184,6 +184,7 @@ pub struct RatelimitsConfig { /// - "5 Hz, 20 burst" /// - "5 per second, 10 burst" /// - "10 per hour, 5 burst" +#[derive(Debug)] pub struct RatelimiterConfig { /// The inner [`Quota`], which this struct is just a wrapper for. pub quota: Quota, @@ -196,13 +197,19 @@ impl<'de> Deserialize<'de> for RatelimiterConfig { { let stringy_format = String::deserialize(deserializer)?; let Some((left, right)) = stringy_format.split_once(',') else { - return Err(D::Error::custom("no comma. ratelimiter string should be like '5 Hz, 20 burst'")); + return Err(D::Error::custom( + "no comma. ratelimiter string should be like '5 Hz, 20 burst'", + )); }; let Some((left_val, left_unit)) = left.trim().split_once(' ') else { - return Err(D::Error::custom("no units on left. ratelimiter string should be like '5 Hz, 20 burst'")); + return Err(D::Error::custom( + "no units on left. ratelimiter string should be like '5 Hz, 20 burst'", + )); }; let Some((right_val, right_unit)) = right.trim().split_once(' ') else { - return Err(D::Error::custom("no units on right. ratelimiter string should be like '5 Hz, 20 burst'")); + return Err(D::Error::custom( + "no units on right. ratelimiter string should be like '5 Hz, 20 burst'", + )); }; let Ok(left_val) = left_val.parse::() else { @@ -237,3 +244,62 @@ impl<'de> Deserialize<'de> for RatelimiterConfig { Ok(RatelimiterConfig { quota }) } } + +#[cfg(test)] +mod test { + use crate::config::RatelimiterConfig; + + #[test] + fn test_ratelimiter_deser_errors() { + fn deser(s: &str) -> Result { + serde_json::from_value(serde_json::Value::String(s.to_owned())) + } + + // this is fine + deser("5 Hz, 20 burst").unwrap(); + + // no comma + assert_eq!( + deser("5 Hz 20 burst").unwrap_err().to_string(), + "no comma. ratelimiter string should be like '5 Hz, 20 burst'" + ); + + // bad numbers + assert_eq!( + deser("five Hz, 20 burst").unwrap_err().to_string(), + "bad value on left. ratelimiter string should be like '5 Hz, 20 burst'" + ); + assert_eq!( + deser("5 Hz, twenty burst").unwrap_err().to_string(), + "bad value on right. ratelimiter string should be like '5 Hz, 20 burst'" + ); + + // wrong order + assert_eq!( + deser("20 burst, 5 Hz").unwrap_err().to_string(), + "bad units on left. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'." + ); + + // no units + assert_eq!( + deser("5 Hz, 20").unwrap_err().to_string(), + "no units on right. ratelimiter string should be like '5 Hz, 20 burst'" + ); + assert_eq!( + deser("5, 20 burst").unwrap_err().to_string(), + "no units on left. ratelimiter string should be like '5 Hz, 20 burst'" + ); + + // bad units + assert_eq!( + deser("20 per milleniumm, 20 burst") + .unwrap_err() + .to_string(), + "bad units on left. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'." + ); + assert_eq!( + deser("5 Hz, 20 wombats").unwrap_err().to_string(), + "bad units on right. ratelimiter string should be like '5 Hz, 20 burst' or '5 per hour, 20 burst'." + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index c888d47..dcf0dfd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,12 @@ #![deny(missing_docs)] +pub mod cli; pub mod config; pub mod passwords; pub mod store; +pub mod utils; pub mod web; + +#[cfg(test)] +mod tests; diff --git a/src/passwords.rs b/src/passwords.rs index 89e8687..c63cb9a 100644 --- a/src/passwords.rs +++ b/src/passwords.rs @@ -52,3 +52,43 @@ pub fn check_hash(password: &str, hash: &str) -> eyre::Result { Ok(argon2.verify_password(password.as_bytes(), &hash).is_ok()) } + +#[cfg(test)] +mod test { + use rstest::{fixture, rstest}; + + use crate::{config::PasswordHashingConfig, passwords::check_hash}; + + use super::create_password_hash; + + /// Password hash for "secret" + const EXAMPLE_SECRET_PASSWORD_HASH: &str = "$argon2id$v=19$m=512,t=1,p=1$Z11PjkMSx/rm4IbDzDmK2Q$VtUH6Iee/GD1FltULyLf6/QRwjNA9d5+mjAKS+WzlZw"; + + /// Password hash for "verysecret" + const EXAMPLE_VERYSECRET_PASSWORD_HASH: &str = "$argon2id$v=19$m=512,t=1,p=1$tcGnJC8EOR4B43z35roeFg$o3grpVK9HAb3850iRI1/nXR+nPZbyFWchmkbkBhm7Co"; + + #[fixture] + fn password_config() -> PasswordHashingConfig { + PasswordHashingConfig { + memory: 512, + iterations: 1, + parallelism: 1, + } + } + + #[rstest] + fn test_valid_password(password_config: PasswordHashingConfig) { + let pwh = create_password_hash("secret", &password_config).unwrap(); + + assert!(check_hash("secret", &pwh).unwrap()); + assert!(check_hash("secret", EXAMPLE_SECRET_PASSWORD_HASH).unwrap()); + } + + #[rstest] + fn test_invalid_password(password_config: PasswordHashingConfig) { + let pwh = create_password_hash("verysecret", &password_config).unwrap(); + + assert!(!check_hash("secret", &pwh).unwrap()); + assert!(!check_hash("secret", EXAMPLE_VERYSECRET_PASSWORD_HASH).unwrap()); + } +} diff --git a/src/store.rs b/src/store.rs index 0ae2db3..43b86b6 100644 --- a/src/store.rs +++ b/src/store.rs @@ -110,7 +110,7 @@ pub struct UserInfo { /// A wrapper around a database transaction with some database methods on it. pub struct IdCoopStoreTxn<'a, 'txn> { - txn: &'a mut Transaction<'txn, Postgres>, + pub(crate) txn: &'a mut Transaction<'txn, Postgres>, } impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { @@ -341,7 +341,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { .await .context("failed to lookup login session")?; - let Some(row) = row_opt else { return Ok(None); }; + let Some(row) = row_opt else { + return Ok(None); + }; Ok(Some(LoginSession { user_name: row.user_name, @@ -373,7 +375,9 @@ impl<'a, 'txn> IdCoopStoreTxn<'a, 'txn> { .await .context("failed to lookup application session")?; - let Some(row) = row_opt else { return Ok(None); }; + let Some(row) = row_opt else { + return Ok(None); + }; Ok(Some(ApplicationSession { application_session_id: row.session_id, diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..8bb3443 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use axum::Router; +use confique::{Config, Partial}; +use josekit::jwk::alg::rsa::RsaKeyPair; +use metrics::atomics::AtomicU64; +use pgtemp::PgTempDB; +use serde_json::json; + +use crate::{ + config::{Configuration, SecretConfig}, + store::IdCoopStore, + utils::{Clock, RandGen}, + web::make_router, +}; + +struct TestSystem { + database: PgTempDB, + web: Router, + config: Arc, + store: Arc, + clock: Clock, +} + +const RSA_KEY_PAIR_PEM: &[u8] = include_bytes!("tests/keypair.pem"); +const RSA_PUBLIC_KEY_PEM: &[u8] = include_bytes!("tests/publickey.crt"); + +mod test_cli; +mod test_oidc_auth_flow; + +async fn basic_system() -> TestSystem { + let temp_db = pgtemp::PgTempDBBuilder::new() + .with_dbname("test_idcoop") + .start_async() + .await; + + let store = IdCoopStore::connect(&temp_db.connection_uri()) + .await + .expect("failed to connect to pgtemp db"); + + let config_partial: ::Partial = + serde_json::from_value(json!({ + "listen": { + // Not useful, not actually used in the tests + "bind": "127.0.0.1:1", + "public_base_uri": "http://idcoop.example.com", + "client_ip_source": "RightmostXForwardedFor", + }, + "postgres": { + "connect": "postgres://not-used-in-tests" + }, + "oidc": { + "issuer": "http://issuer.example.com", + "rsa_keypair": "not-used-in-tests", + "clients": { + "aclient": { + "redirect_uris": [ + "http://aclient.example.com/redirect" + ], + "name": "AClient", + "allow_user_classes": ["active"], + "secret": "secretA", + } + } + }, + "password_hashing": { + // Use weak password hash settings; we're not testing Argon2 here, + // we just want it to be fast. + "memory": 512, + "iterations": 1, + }, + "ratelimits": { + "login": "3 per hour, 2 burst", + }, + })) + .expect("bad test config"); + let config = + Configuration::from_partial(config_partial.with_fallback(Partial::default_values())) + .expect("failed to load builtin config"); + + let secrets = SecretConfig { + rsa_key_pair: RsaKeyPair::from_pem(RSA_KEY_PAIR_PEM) + .expect("failed to decode builtin RSA keypair"), + }; + + let config = Arc::new(config); + let store = Arc::new(store); + let clock = Clock(Arc::new(AtomicU64::new(0))); + let randgen = RandGen::new(); + let router = make_router( + store.clone(), + config.clone(), + Arc::new(secrets), + clock.clone(), + randgen, + ) + .await + .expect("failed to make router"); + + TestSystem { + database: temp_db, + web: router, + config, + store, + clock, + } +} + +#[tokio::test] +async fn test_demo() { + let basic = basic_system().await; +} diff --git a/src/tests/keypair.pem b/src/tests/keypair.pem new file mode 100644 index 0000000..c6719a7 --- /dev/null +++ b/src/tests/keypair.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDDu6acOa+3ae2S +0llp5oMsXjBMd5QJeQJCcY5Q9NAITF2U9VBwAiMf2wmaTZ1aWWFGSb/zWef7Hx1e +qNhsK9MYL+QdJih2I+KpMtDWm7hhy9FtCHVc9i1z9PruXb0om2jDWLuBkPdCqJZT +C58ObZKmgL4OH5F1Qv5JR/ZX21OjLolXPJo1sonLv9mlgufvhUmnC17onSSqFLBA +nhUedjbfdnLShkp0xa8G0nW7Ls7Idaxyo8M5S2M+azJyLI87eqjjfz0yIW0Am890 +mRa81hO2D+YNcVA2wIE9MEI/ie480YxLQ0VHCjX4DcVir4ceExysYkdL8VK+U14g +NO67k4NVAgMBAAECggEABnseJyoZ0V7miOgCIemKClwMCVwkQLQLCRwtdCzG/p9Y +sef1g9/uPc3I4Z0USruO5v7mJi6h6cS7+jhpAhvpX3GmgfiTemXxyVxvYcvCLSrM +gmm3SR61npNMA7yC2OdcbqtvefjM1x4x7AoEeDvUkULOCDWvYUyYkuCZHYubl1mS +Rtcp9rxzky2tjdp8CHySBa9Kz9LEjWdFGky7g3vSyqZtw6tkK5CTwMPb9aHwiEq1 +yDWCbqAPPnb300dXSqx7z3AcsxBi/lpCs79fQS1vSJ1/L9POpYxX0SMtffkR5+Bl +Mkg9dZUVenfbP4n40FdMypTTX2KJiMkc8+f0+Tz9QQKBgQDrdT7xs0lqMeZWp2rD +y2KLzRl0iO/+yKse+BgkeBggZ/Vh2TeF9ylTRYdmimxkJ8eZRipTr0F64BS/LPEk +RgWviuf9dUl0gyuhYTOJgUw1wcgtB6e+04UKEUsQW6JoNZekkM+xTa7FskLmlJ4J +zzowjF1lgJeEyX1tvWXKIVe+/QKBgQDUzy5nZoUqBrVTYxL2uadIUh8oiBKPU93U +Gz3DUq90yfDa7lFhwMRQRXfNqGUy6tshsaF4fT1b62hZDSz1OH3h/y1LKQOdF5kc +JJyk/4b7NJna16kwBLzWje5SjQKr51aQWU8JftZ5/8uck2j7vMi+mgwzpG45J7kv +Q1I5decBOQKBgFq1sKotB/uBfdukY91KXYy+VzAuEUd2x3YG3kYufhz97+riZCGY +NrN99cvrSBbNvHewMF5NBkzwRw3foob28vnN6dIbfVEFt6lUaSZwSYvsO9IdQOKj +Wn2ma+TBaK/89Y7QuzLzWoGPS3bJipj83M4XRWP1RmpBtbCxZqWYctWBAoGAMZPi +16wGsffGHpsiO+CcnDilkafByypar6N5DBwjTC4PsrF6vC9QjPLiKkNk8CvOyVa8 +q3lh5hw9vyFWq/pxOUldn/j6Iorw3KGa7MWrCLMEdPtxKwKvi7ydHRZE3Q+UFyT3 +SNsH1HxHTz74Yk1k5yK0XQOduisK9XvVmBVjr+ECgYEAyoSbo/1cyLKWgrIr0K/f +stiKL9SmBmYbaGaxtQToB5Hnqso7Hz5YEDlrcr8s1ukEFghgeNYuDYw3ZKKGGfZm +yVQKAt8ouoO8rfkLrtt0H+/0uJgouhewDEqf/O+MfzwDnFcT89J5ZTEf+9n6pjry +fuiQnuwEsPYGCCFuWWlrdHQ= +-----END PRIVATE KEY----- diff --git a/src/tests/publickey.crt b/src/tests/publickey.crt new file mode 100644 index 0000000..bd89332 --- /dev/null +++ b/src/tests/publickey.crt @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAw7umnDmvt2ntktJZaeaD +LF4wTHeUCXkCQnGOUPTQCExdlPVQcAIjH9sJmk2dWllhRkm/81nn+x8dXqjYbCvT +GC/kHSYodiPiqTLQ1pu4YcvRbQh1XPYtc/T67l29KJtow1i7gZD3QqiWUwufDm2S +poC+Dh+RdUL+SUf2V9tToy6JVzyaNbKJy7/ZpYLn74VJpwte6J0kqhSwQJ4VHnY2 +33Zy0oZKdMWvBtJ1uy7OyHWscqPDOUtjPmsyciyPO3qo4389MiFtAJvPdJkWvNYT +tg/mDXFQNsCBPTBCP4nuPNGMS0NFRwo1+A3FYq+HHhMcrGJHS/FSvlNeIDTuu5OD +VQIDAQAB +-----END PUBLIC KEY----- diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__1. no auth token.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__1. no auth token.snap new file mode 100644 index 0000000..049da8e --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__1. no auth token.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "16" + content-type: text/plain; charset=utf-8 +- No access token. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2. malformed auth token.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2. malformed auth token.snap new file mode 100644 index 0000000..c581f38 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2. malformed auth token.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "21" + content-type: text/plain; charset=utf-8 +- Invalid access token. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2__login.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2__login.snap new file mode 100644 index 0000000..a4f9e71 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__2__login.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "238" + content-type: text/html; charset=utf-8 + set-cookie: __Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000 + x-frame-options: DENY +- "
UN PW (temporary form)
" diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3. wrong auth token.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3. wrong auth token.snap new file mode 100644 index 0000000..0141afa --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3. wrong auth token.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "28" + content-type: text/plain; charset=utf-8 +- Invalid application session. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap new file mode 100644 index 0000000..3082a19 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__login.snap @@ -0,0 +1,10 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "55" + content-type: text/plain; charset=utf-8 + location: "/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge=LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M&code_challenge_method=S256&nonce=noncey" + set-cookie: __Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY; HttpOnly; SameSite=Strict; Secure; Path=/; Max-Age=43200000 + x-frame-options: DENY +- Logged in. Redirecting you back to what you were doing. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__token_no_code.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__token_no_code.snap new file mode 100644 index 0000000..53b3600 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__3__token_no_code.snap @@ -0,0 +1,10 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "75" + content-type: application/json +- error: invalid_request + error_description: "`code` parameter missing." diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap new file mode 100644 index 0000000..0f70f51 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__auth.snap @@ -0,0 +1,8 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "288" + content-type: text/html; charset=utf-8 + x-frame-options: DENY +- "hi robert, consent to AClient?
" diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_conflict.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_conflict.snap new file mode 100644 index 0000000..6b479d4 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_conflict.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "124" + content-type: application/json +- "{\"error\":\"invalid_grant\",\"error_description\":\"Auth code has been redeemed multiple times! This could mean something nasty.\"}" diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_malformed_code.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_malformed_code.snap new file mode 100644 index 0000000..01444ff --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__4__token_malformed_code.snap @@ -0,0 +1,10 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "77" + content-type: application/json +- error: invalid_request + error_description: "`code` parameter malformed." diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap new file mode 100644 index 0000000..b59eedf --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__auth.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- content-length: "46" + content-type: text/plain; charset=utf-8 + location: "http://aclient.example.com/redirect?code=UnLS_bGq0ZB4szozTRCJIG-37ibG08zK&state=wombat&iss=http%3A%2F%2Fissuer.example.com" + x-frame-options: DENY +- Authorisation succeeded; redirecting you back. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__token_no_verifier.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__token_no_verifier.snap new file mode 100644 index 0000000..23a3f3b --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__5__token_no_verifier.snap @@ -0,0 +1,10 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "84" + content-type: application/json +- error: invalid_request + error_description: "`code_verifier` parameter missing." diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap new file mode 100644 index 0000000..270771b --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token.snap @@ -0,0 +1,14 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "803" + content-type: application/json +- access_token: pvgYf08qA_ctEIhMP4DFQzbxjiCx8qfgi4cATwGsH9Q + expires_in: 31536000 + id_token: eyJ0eXAiOiJKV1QiLCJraWQiOiJ0aGVrZXkiLCJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJodHRwOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIiwiYXVkIjoiYWNsaWVudCIsImV4cCI6MzE1MzYwMDAzMCwiaWF0IjozMCwiYXV0aF90aW1lIjowLCJub25jZSI6Im5vbmNleSJ9.QQEhDgAcF2vBg2J6ledDzk_4ks4GyquJgMSE4KUREtTUVZbLpfa52sro8lPiBnFPCOz_DkSfpm4OQq8429mwcqfoyS-uBjtgPq7eij7kOa3BTrb9eC8rScGuDX0wJ9XZV-v0f3dun_sYhvH3smLqPoTF4wxtgT5b_2SCmnuqL2cmKN-GFox4mjmdoPzQxhAyTKlj_HkGHQjkl-nP96-71QeM5KwyLQes_OWU2HSEt9uiemUsEr4pMPv-po7QkrU5p2sJc5udcaUAOtuV9tpt5qg8P9TPWYo4M1GbMsbTyWYDhsmtusNKB6N2srwZwB9QwgE4DxoeoqmKlJf0BGF4Bg + refresh_token: _AHNodjMMCQJw2Bq3bsvDSZnjI4FEB2DDMK8ZHgDej8 + scope: openid + token_type: Bearer diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token_wrong_verifier.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token_wrong_verifier.snap new file mode 100644 index 0000000..709152b --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__6__token_wrong_verifier.snap @@ -0,0 +1,10 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "74" + content-type: application/json +- error: invalid_grant + error_description: Code challenge is invalid. diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__7__userinfo.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__7__userinfo.snap new file mode 100644 index 0000000..133c97f --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__7__userinfo.snap @@ -0,0 +1,11 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, json)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "92" + content-type: application/json +- name: robert + preferred_username: robert + sub: 00000000-0000-0000-0000-000000000000 diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__discovery_endpoint.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__discovery_endpoint.snap new file mode 100644 index 0000000..0a55506 --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__discovery_endpoint.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "505" + content-type: application/json +- "{\"issuer\":\"http://idcoop.example.com\",\"authorization_endpoint\":\"http://idcoop.example.com/oidc/auth\",\"token_endpoint\":\"http://idcoop.example.com/oidc/token\",\"userinfo_endpoint\":\"http://idcoop.example.com/oidc/userinfo\",\"jwks_uri\":\"http://idcoop.example.com/oidc/jwks\",\"scopes_supported\":[\"openid\"],\"response_types_supported\":[\"code\"],\"response_modes_supported\":[\"query\"],\"grant_types_supported\":[\"authorization_code\"],\"subject_types_supported\":[\"public\"],\"id_token_signing_alg_values_supported\":[\"RS256\"]}" diff --git a/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__jwks_endpoint.snap b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__jwks_endpoint.snap new file mode 100644 index 0000000..f4c4c5e --- /dev/null +++ b/src/tests/snapshots/idcoop__tests__test_oidc_auth_flow__jwks_endpoint.snap @@ -0,0 +1,9 @@ +--- +source: src/tests/test_oidc_auth_flow.rs +expression: "(headers, text)" +--- +- access-control-allow-origin: "*" + access-control-expose-headers: "*" + content-length: "425" + content-type: application/json +- "{\"keys\":[{\"kty\":\"RSA\",\"n\":\"w7umnDmvt2ntktJZaeaDLF4wTHeUCXkCQnGOUPTQCExdlPVQcAIjH9sJmk2dWllhRkm_81nn-x8dXqjYbCvTGC_kHSYodiPiqTLQ1pu4YcvRbQh1XPYtc_T67l29KJtow1i7gZD3QqiWUwufDm2SpoC-Dh-RdUL-SUf2V9tToy6JVzyaNbKJy7_ZpYLn74VJpwte6J0kqhSwQJ4VHnY233Zy0oZKdMWvBtJ1uy7OyHWscqPDOUtjPmsyciyPO3qo4389MiFtAJvPdJkWvNYTtg_mDXFQNsCBPTBCP4nuPNGMS0NFRwo1-A3FYq-HHhMcrGJHS_FSvlNeIDTuu5ODVQ\",\"e\":\"AQAB\",\"use\":\"sig\",\"kid\":\"thekey\",\"alg\":\"RS256\"}]}" diff --git a/src/tests/test_cli.rs b/src/tests/test_cli.rs new file mode 100644 index 0000000..1413558 --- /dev/null +++ b/src/tests/test_cli.rs @@ -0,0 +1,164 @@ +use rstest::rstest; + +use crate::cli::{handle_user_command, UserCommand}; + +use super::basic_system; + +#[rstest] +#[tokio::test] +async fn test_cli_add_user() { + let sys = basic_system().await; + + handle_user_command( + UserCommand::Add { + username: "jonathan".to_owned(), + locked: true, + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(user.unwrap().locked); + Ok(()) + }) + }) + .await + .unwrap(); +} + +#[rstest] +#[tokio::test] +async fn test_cli_lock_and_unlock_user() { + let sys = basic_system().await; + + handle_user_command( + UserCommand::Add { + username: "jonathan".to_owned(), + locked: false, + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(!user.unwrap().locked); + Ok(()) + }) + }) + .await + .unwrap(); + + handle_user_command( + UserCommand::Lock { + username: "jonathan".to_owned(), + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(user.unwrap().locked); + Ok(()) + }) + }) + .await + .unwrap(); + + handle_user_command( + UserCommand::Unlock { + username: "jonathan".to_owned(), + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(!user.unwrap().locked); + Ok(()) + }) + }) + .await + .unwrap(); +} + +#[rstest] +#[tokio::test] +async fn test_cli_del_user() { + let sys = basic_system().await; + + handle_user_command( + UserCommand::Add { + username: "jonathan".to_owned(), + locked: true, + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(user.unwrap().locked); + Ok(()) + }) + }) + .await + .unwrap(); + + handle_user_command( + UserCommand::Delete { + username: "jonathan".to_owned(), + }, + &sys.config, + &sys.store, + ) + .await + .unwrap(); + + let _: () = sys + .store + .txn(|mut txn| { + Box::pin(async move { + let user = txn.lookup_user_by_name("jonathan".to_owned()).await?; + + assert!(user.is_none()); + Ok(()) + }) + }) + .await + .unwrap(); +} diff --git a/src/tests/test_oidc_auth_flow.rs b/src/tests/test_oidc_auth_flow.rs new file mode 100644 index 0000000..91f979b --- /dev/null +++ b/src/tests/test_oidc_auth_flow.rs @@ -0,0 +1,472 @@ +//! Tests the OpenID Connect auth flow + +use std::collections::BTreeMap; + +use axum::http::StatusCode; +use axum_test_helper::{TestClient, TestResponse}; +use insta::assert_yaml_snapshot; + +use maplit::btreemap; +use sqlx::types::Uuid; + +use crate::{passwords::create_password_hash, tests::basic_system}; + +async fn dump_resp_text( + req_name: &str, + resp: TestResponse, +) -> (StatusCode, BTreeMap, String) { + let status = resp.status(); + // convert headers to a simple B-Tree map so they can be serialised in snapshots + // easily + let mut headers: BTreeMap = resp + .headers() + .clone() + .into_iter() + // skip headers that are duplicate keys. + // We don't care about theme for our purposes. + .filter_map(|(k, v)| Some((k?.to_string(), v.to_str().unwrap().to_owned()))) + .collect(); + // Remove date because it's not stable across tests! + headers.remove("date"); + // Remove vary because it has multiple values and we don't want to + // introduce instability into our tests by only allowing one through. + headers.remove("vary"); + let text = resp.text().await; + eprintln!("=== Response for {req_name} ==="); + eprintln!("Status: {status:?}"); + eprintln!("Headers: {headers:#?}"); + eprintln!("Body: {text:?}"); + eprintln!("=== End of response ==="); + (status, headers, text) +} + +/// Tests the full flow... +#[tokio::test] +async fn test_full_flow() { + let sys = basic_system().await; + + let uuid = Uuid::nil(); + let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap(); + let _: () = sys + .store + .txn(|txn| { + Box::pin(async move { + sqlx::query( + "INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id", + ).bind("robert").bind(uuid).bind(pwhash).bind(false) + .fetch_one(&mut **txn.txn) + .await.unwrap(); + Ok(()) + }) + }) + .await + .unwrap(); + + let client = TestClient::new(sys.web); + + ///// These requests are on behalf of the user's browser ///// + + const CODE_VERIFIER: &str = "verifying"; + // base64(sha256("verifying")) + const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M"; + + // 1. /auth request + let login_url = format!("/login?then=%2Foidc%2Fauth%3Fscope%3Dopenid%26client_id%3Daclient%26response_type%3Dcode%26state%3Dwombat%26redirect_uri%3Dhttp%3A%252F%252Faclient.example.com%252Fredirect%26code_challenge%3D{CODE_CHALLENGE}%26code_challenge_method%3DS256%26nonce%3Dnoncey"); + let resp = client.get(&format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey")).send().await; + let (status, headers, _text) = dump_resp_text("1. /auth request", resp).await; + assert_eq!(status, 302); + assert_eq!(headers.get("location").unwrap(), &login_url); + + // 2. /login request + let resp = client.get(&login_url).send().await; + let (status, headers, text) = dump_resp_text("2. /login request", resp).await; + assert_eq!(status, 200); + assert_yaml_snapshot!("2/login", (headers, text)); + + // 3. /login request with credentials + let resp = client + .post(&login_url) + .form(&btreemap! { + "username" => "robert", + "password" => "secret", + "xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w", + }) + .header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w") + // /login is rate-limited by IP source and needs an IP + .header("X-Forwarded-For", "0.0.0.0") + .send() + .await; + let (status, headers, text) = dump_resp_text("3. /login request with credentials", resp).await; + assert_eq!(status, 302); + let auth_loc = headers.get("location").unwrap().to_owned(); + assert_yaml_snapshot!("3/login", (headers, text)); + + // 4. /auth request + let resp = client + .get(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY", + ) + .send() + .await; + let (status, headers, text) = dump_resp_text("4. GET /auth after login", resp).await; + assert_eq!(status, 200); + assert_yaml_snapshot!("4/auth", (headers, text)); + + sys.clock.set_time(30); + + // 5. /auth request with confirmation + let resp = client + .post(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=Glh_a6j2xs7ryaJWefPsoW59L7xq6QokAzGh-zEcOxY", + ) + .form(&btreemap! { + "action" => "accept", + "xsrf" => "0.JpKyqkWckzF6w6btxX2RXv_MlxgOfoYOZJknydValkc", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("5. POST /auth after confirmation", resp).await; + assert_eq!(status, 302); + // Note this snapshot includes a Location: header back to the client application + assert_yaml_snapshot!("5/auth", (headers, text)); + + ///// At this point, we make requests on behalf of the client ///// + + // 6. /token request + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "UnLS_bGq0ZB4szozTRCJIG-37ibG08zK", + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("6. POST /token", resp).await; + assert_eq!(status, 200); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("6/token", (headers, json)); + + // 7. /userinfo request + let resp = client + .get("/oidc/userinfo") + .header( + "Authorization", + "Bearer pvgYf08qA_ctEIhMP4DFQzbxjiCx8qfgi4cATwGsH9Q", + ) + .send() + .await; + let (status, headers, text) = dump_resp_text("7. /userinfo", resp).await; + assert_eq!(status, 200); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("7/userinfo", (headers, json)); +} + +#[tokio::test] +async fn test_jwks_endpoint() { + let sys = basic_system().await; + let client = TestClient::new(sys.web); + let resp = client.get("/oidc/jwks").send().await; + let (status, headers, text) = dump_resp_text("/jwks", resp).await; + assert_eq!(status, 200); + assert_yaml_snapshot!((headers, text)); +} + +#[tokio::test] +async fn test_discovery_endpoint() { + let sys = basic_system().await; + let client = TestClient::new(sys.web); + let resp = client.get("/.well-known/openid-configuration").send().await; + let (status, headers, text) = dump_resp_text("discovery", resp).await; + assert_eq!(status, 200); + assert_yaml_snapshot!((headers, text)); +} + +#[tokio::test] +async fn test_userinfo_bad_auth() { + let sys = basic_system().await; + let client = TestClient::new(sys.web); + + // 1. no auth token + let resp = client.get("/oidc/userinfo").send().await; + let (status, headers, text) = dump_resp_text("1. no auth token", resp).await; + assert_eq!(status, 401); + assert_yaml_snapshot!("1. no auth token", (headers, text)); + + // 2. malformed access token + let resp = client + .get("/oidc/userinfo") + .header("Authorization", "Bearer ++++") + .send() + .await; + let (status, headers, text) = dump_resp_text("2. malformed auth token", resp).await; + assert_eq!(status, 401); + assert_yaml_snapshot!("2. malformed auth token", (headers, text)); + + // 3. wrong access token + let resp = client + .get("/oidc/userinfo") + .header("Authorization", "Bearer aaaa") + .send() + .await; + let (status, headers, text) = dump_resp_text("3. wrong auth token", resp).await; + assert_eq!(status, 401); + assert_yaml_snapshot!("3. wrong auth token", (headers, text)); +} + +/// Tests error conditions in the /token endpoint +#[tokio::test] +async fn test_token_errors() { + let sys = basic_system().await; + + let uuid = Uuid::nil(); + let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap(); + let _: () = sys + .store + .txn(|txn| { + Box::pin(async move { + sqlx::query( + "INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id", + ).bind("robert").bind(uuid).bind(pwhash).bind(false) + .fetch_one(&mut **txn.txn) + .await.unwrap(); + Ok(()) + }) + }) + .await + .unwrap(); + + let client = TestClient::new(sys.web); + + ///// These requests are on behalf of the user's browser ///// + + const CODE_VERIFIER: &str = "verifying"; + // base64(sha256("verifying")) + const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M"; + + // 1. /login request with credentials + let resp = client + .post("/login") + .form(&btreemap! { + "username" => "robert", + "password" => "secret", + "xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w", + }) + .header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w") + // /login is rate-limited by IP source and needs an IP + .header("X-Forwarded-For", "0.0.0.0") + .send() + .await; + let (status, _headers, _text) = + dump_resp_text("1. /login request with credentials", resp).await; + assert_eq!(status, 302); + + let auth_loc = format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey"); + + // 2. /auth request with confirmation + let resp = client + .post(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE", + ) + .form(&btreemap! { + "action" => "accept", + "xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y", + }) + .send() + .await; + let (status, _headers, _text) = dump_resp_text("2. POST /auth after confirmation", resp).await; + assert_eq!(status, 302); + eprintln!("{:?}", _text); + + ///// At this point, we make requests on behalf of the client ///// + + // 3. /token request with no code + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("3. /token no code", resp).await; + assert_eq!(status, 400); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("3/token_no_code", (headers, json)); + + // 4. /token request with malformed code (not long enough) + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "aaaa", + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("4. /token malformed code", resp).await; + assert_eq!(status, 400); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("4/token_malformed_code", (headers, json)); + + // 5. /token request with no code_verifier + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ", + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("5. /token no verifier", resp).await; + assert_eq!(status, 400); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("5/token_no_verifier", (headers, json)); + + // 6. /token request with wrong code_verifier + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ", + "code_verifier" => "i'm wrong", + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("6. /token wrong verifier", resp).await; + assert_eq!(status, 400); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + assert_yaml_snapshot!("6/token_wrong_verifier", (headers, json)); +} + +/// Tests double-requesting a /token (auth code conflict) +#[tokio::test] +async fn test_token_conflict() { + let sys = basic_system().await; + + let uuid = Uuid::nil(); + let pwhash = create_password_hash("secret", &sys.config.password_hashing).unwrap(); + let _: () = sys + .store + .txn(|txn| { + Box::pin(async move { + sqlx::query( + "INSERT INTO users (user_name, user_id, created_at_utc, password_hash, locked) VALUES ($1, $2, NOW(), $3, $4) RETURNING user_id", + ).bind("robert").bind(uuid).bind(pwhash).bind(false) + .fetch_one(&mut **txn.txn) + .await.unwrap(); + Ok(()) + }) + }) + .await + .unwrap(); + + let client = TestClient::new(sys.web); + + ///// These requests are on behalf of the user's browser ///// + + const CODE_VERIFIER: &str = "verifying"; + // base64(sha256("verifying")) + const CODE_CHALLENGE: &str = "LeU9Sprdh-i2mzasKGh8-hmbnmzk48l3Siw390dKY3M"; + + // 1. /login request with credentials + let resp = client + .post("/login") + .form(&btreemap! { + "username" => "robert", + "password" => "secret", + "xsrf" => "HL4qRFKUlBqkrPTvAQ6z-w", + }) + .header("Cookie", "__Host-SessionlessXsrf=HL4qRFKUlBqkrPTvAQ6z-w") + // /login is rate-limited by IP source and needs an IP + .header("X-Forwarded-For", "0.0.0.0") + .send() + .await; + let (status, _headers, _text) = + dump_resp_text("1. /login request with credentials", resp).await; + assert_eq!(status, 302); + + let auth_loc = format!("/oidc/auth?scope=openid&client_id=aclient&response_type=code&state=wombat&redirect_uri=http:%2F%2Faclient.example.com%2Fredirect&code_challenge={CODE_CHALLENGE}&code_challenge_method=S256&nonce=noncey"); + + // 2. /auth request with confirmation + let resp = client + .post(&auth_loc) + .header( + "Cookie", + "__Host-LoginSession=HL4qRFKUlBqkrPTvAQ6z-xpYf2uo9sbO68miVnnz7KE", + ) + .form(&btreemap! { + "action" => "accept", + "xsrf" => "0.48qkqIorf3dyk1LgVQwyNT82yDHyqHbXge09Rvfsz8Y", + }) + .send() + .await; + let (status, _headers, _text) = dump_resp_text("2. POST /auth after confirmation", resp).await; + assert_eq!(status, 302); + eprintln!("{:?}", _text); + + ///// At this point, we make requests on behalf of the client ///// + + // 3. /token request (successful) + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ", + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, _headers, text) = dump_resp_text("3. POST /token", resp).await; + assert_eq!(status, 200); + let json: BTreeMap = serde_json::from_str(&text).unwrap(); + let access_token = json + .get("access_token") + .unwrap() + .as_str() + .unwrap() + .to_owned(); + + // 4. /token request (conflicting) + let resp = client + .post("/oidc/token") + .header("Authorization", "Basic YWNsaWVudDpzZWNyZXRB") + .form(&btreemap! { + "code" => "LRtIBH5rO3O7hwWaF_UkuFJy0v2xqtGQ", + "code_verifier" => CODE_VERIFIER, + "grant_type" => "authorization_code", + "redirect_uri" => "http://aclient.example.com/redirect", + }) + .send() + .await; + let (status, headers, text) = dump_resp_text("4. POST /token (conflict)", resp).await; + assert_eq!(status, 400); + assert_yaml_snapshot!("4/token_conflict", (headers, text)); + + // 5. /userinfo (using the access token that should now have expired) + let resp = client + .get("/oidc/userinfo") + .header("Authorization", format!("Bearer {access_token}")) + .send() + .await; + let (status, _headers, _text) = dump_resp_text("7. /userinfo", resp).await; + assert_eq!(status, 401); +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..ab0a764 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,145 @@ +//! Miscellaneous utilities + +#[cfg(not(test))] +pub use self::real_utils::{Clock, RandGen}; + +#[cfg(test)] +pub use self::test_utils::{Clock, RandGen}; + +#[cfg(not(test))] +mod real_utils { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use chrono::{DateTime, Utc}; + use rand::{thread_rng, RngCore}; + + /// A source of random numbers that can be faked for tests. + #[derive(Clone)] + pub struct RandGen; + + impl RngCore for RandGen { + fn next_u32(&mut self) -> u32 { + thread_rng().next_u32() + } + + fn next_u64(&mut self) -> u64 { + thread_rng().next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + thread_rng().fill_bytes(dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + thread_rng().try_fill_bytes(dest) + } + } + + /// A source of time that can be faked for tests. + #[derive(Clone)] + pub struct Clock; + + impl Clock { + /// Returns the current time as a `DateTime`. + pub fn now_utc(&self) -> DateTime { + Utc::now() + } + + /// Returns the current time as a i64 of seconds since the Unix epoch. + pub fn now_timestamp(&self) -> i64 { + Utc::now().timestamp() + } + + /// Sleep until the given timestamp. + pub async fn sleep_until(&self, until_ts: i64) { + if until_ts < 0 { + return; + } + let sleep_until = UNIX_EPOCH + Duration::from_secs(until_ts as u64); + let now = SystemTime::now(); + let sleep_for = sleep_until + .duration_since(now) + .unwrap_or(Duration::from_secs(0)); + tokio::time::sleep(sleep_for).await + } + } +} + +#[cfg(test)] +mod test_utils { + use std::sync::atomic::AtomicU64; + use std::sync::Arc; + use std::sync::Mutex; + use std::time::Duration; + + use chrono::{DateTime, TimeZone, Utc}; + use rand::{RngCore, SeedableRng}; + use rand_xoshiro::Xoshiro256StarStar; + + #[derive(Clone)] + pub struct RandGen(Arc>); + + impl RandGen { + #[allow(clippy::new_without_default)] + pub fn new() -> RandGen { + RandGen(Arc::new(Mutex::new(Xoshiro256StarStar::seed_from_u64( + 424242, + )))) + } + } + + impl RngCore for RandGen { + fn next_u32(&mut self) -> u32 { + let mut rng = self.0.lock().unwrap(); + rng.next_u32() + } + + fn next_u64(&mut self) -> u64 { + let mut rng = self.0.lock().unwrap(); + rng.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + let mut rng = self.0.lock().unwrap(); + rng.fill_bytes(dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + let mut rng = self.0.lock().unwrap(); + rng.try_fill_bytes(dest) + } + } + + #[derive(Clone)] + pub struct Clock(pub Arc); + + impl Clock { + pub fn new_test() -> Self { + Clock(Arc::new(AtomicU64::new(0))) + } + pub fn set_time(&self, new: u64) { + self.0.store(new, std::sync::atomic::Ordering::Relaxed); + } + + pub fn now_utc(&self) -> DateTime { + Utc.timestamp_opt(self.0.load(std::sync::atomic::Ordering::Relaxed) as i64, 0) + .earliest() + .unwrap() + } + + pub fn now_timestamp(&self) -> i64 { + self.0.load(std::sync::atomic::Ordering::Relaxed) as i64 + } + + pub async fn sleep_until(&self, until_ts: i64) { + if until_ts < 0 { + return; + } + // Wait for time to advance past the requested timestamp + // TODO write a better test sleep implementation + while self.now_timestamp() < until_ts { + tokio::time::sleep(Duration::from_millis(1)).await; + } + } + } +} diff --git a/src/web.rs b/src/web.rs index 4475987..281ff73 100644 --- a/src/web.rs +++ b/src/web.rs @@ -24,16 +24,16 @@ use axum::{ routing::{get, post}, Extension, Router, }; -use eyre::Context; use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, RateLimiter}; use hornbeam::{initialise_template_manager, make_template_manager}; use tower_cookies::CookieManagerLayer; use tower_http::{cors::CorsLayer, set_header::SetResponseHeaderLayer, trace::TraceLayer}; -use tracing::{error, info}; +use tracing::error; use crate::{ config::{Configuration, RatelimiterConfig, RatelimitsConfig, SecretConfig}, store::IdCoopStore, + utils::{Clock, RandGen}, web::{ login::{get_login, post_login, PasswordHashInflightLimiter}, oauth_openid::{ @@ -55,14 +55,15 @@ make_template_manager! { }; } -/// Serves, on the bind address specified, the HTTP service -/// including a user interface and any OAuth, OpenID Connect and custom APIs. -pub async fn serve( - bind: SocketAddr, +/// Make an axum `Router` but do not bind it to a port. +/// This exposition allows us to perform integration testing easily. +pub(crate) async fn make_router( store: Arc, config: Arc, secrets: Arc, -) -> eyre::Result<()> { + clock: Clock, + randgen: RandGen, +) -> eyre::Result { initialise_template_manager!(TEMPLATING); let client_ip_source = config.listen.client_ip_source.clone(); @@ -124,7 +125,26 @@ pub async fn serve( .layer(Extension(Arc::new(PasswordHashInflightLimiter::new(1)))) .layer(client_ip_source.into_extension()) .layer(Extension(Arc::new(ratelimiters))) - .layer(Extension(VolatileCodeStore::default())); + .layer(Extension(VolatileCodeStore::new(clock.clone()))) + .layer(Extension(clock)) + .layer(Extension(randgen)); + + Ok(router) +} + +/// Serves, on the bind address specified, the HTTP service +/// including a user interface and any OAuth, OpenID Connect and custom APIs. +#[cfg(not(test))] +pub async fn serve( + bind: SocketAddr, + store: Arc, + config: Arc, + secrets: Arc, +) -> eyre::Result<()> { + use eyre::Context; + use tracing::info; + + let router = make_router(store, config, secrets, Clock, RandGen).await?; info!("Listening on {bind:?}"); axum::Server::try_bind(&bind) @@ -213,7 +233,7 @@ impl Ratelimiters { /// Do some housekeeping if it hasn't been done recently. pub fn housekeeping(&self) { let Ok(now) = SystemTime::now().duration_since(UNIX_EPOCH) else { - return + return; }; let now = (now.as_secs() >> 2) as u32; diff --git a/src/web/login.rs b/src/web/login.rs index 93bfc25..cce0727 100644 --- a/src/web/login.rs +++ b/src/web/login.rs @@ -9,8 +9,8 @@ use std::{ use async_trait::async_trait; use axum::{ extract::{FromRequestParts, Query}, - headers::Cookie, - http::{request::Parts, uri::PathAndQuery, HeaderValue, StatusCode}, + headers::Cookie as CookieHeader, + http::{request::Parts, uri::PathAndQuery, StatusCode}, response::{Html, IntoResponse, Response}, Extension, Form, TypedHeader, }; @@ -21,17 +21,18 @@ use chrono::{DateTime, Duration, TimeZone, Utc}; use eyre::eyre; use eyre::{bail, Context, ContextCompat}; use governor::Jitter; -use rand::{thread_rng, Rng}; +use rand::Rng; use serde::Deserialize; use sqlx::types::Uuid; use tokio::sync::Semaphore; -use tower_cookies::Cookies; +use tower_cookies::{Cookie, Cookies}; use tracing::error; use crate::{ config::{Configuration, PasswordHashingConfig}, passwords::{check_hash, create_password_hash}, store::IdCoopStore, + utils::RandGen, }; use super::{sessionless_xsrf, Ratelimiters, WebResult}; @@ -203,17 +204,15 @@ where type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Ok(cookies) = TypedHeader::::from_request_parts(parts, state).await else { + let Ok(cookies) = TypedHeader::::from_request_parts(parts, state).await + else { return Err((StatusCode::UNAUTHORIZED, "No login session.")); }; let Some(cookie_val) = cookies.get("__Host-LoginSession").map(str::to_owned) else { return Err((StatusCode::UNAUTHORIZED, "No login session.")); }; let Ok(login_session_token) = BASE64_URL_SAFE_NO_PAD.decode(&cookie_val) else { - return Err(( - StatusCode::UNAUTHORIZED, - "Invalid login session token." - )); + return Err((StatusCode::UNAUTHORIZED, "Invalid login session token.")); }; if login_session_token.len() != LOGIN_SESSION_TOKEN_BYTES { return Err((StatusCode::UNAUTHORIZED, "Invalid login session token.")); @@ -268,11 +267,12 @@ pub async fn get_login( current_session: Option, Query(query): Query, cookies: Cookies, + Extension(mut randgen): Extension, ) -> Response { match current_session { Some(_session) => make_post_login_redirect(query.then), None => { - let xsrf_token = sessionless_xsrf::get_token(&cookies); + let xsrf_token = sessionless_xsrf::get_token(&cookies, &mut randgen); Html(format!("
UN PW (temporary form)
", xsrf_token)).into_response() } } @@ -335,6 +335,7 @@ pub async fn post_login( Extension(phil): Extension>, SecureClientIp(src_ip): SecureClientIp, Extension(ratelimiters): Extension>, + Extension(mut randgen): Extension, Form(form): Form, ) -> WebResult { ratelimiters.housekeeping(); @@ -344,7 +345,7 @@ pub async fn post_login( .await; if !sessionless_xsrf::check_token(&cookies, &form.xsrf) { // Invalid XSRF token: try again - return Ok(get_login(None, Query(query), cookies).await); + return Ok(get_login(None, Query(query), cookies, Extension(randgen)).await); } // retrieve user details @@ -400,11 +401,11 @@ pub async fn post_login( }; // Generate a login session token and store the hash in our database - let login_session_token = thread_rng().gen::<[u8; LOGIN_SESSION_TOKEN_BYTES]>(); + let login_session_token = randgen.gen::<[u8; LOGIN_SESSION_TOKEN_BYTES]>(); let login_session_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(login_session_token); let login_session_token_hash: [u8; LOGIN_SESSION_TOKEN_HASH_BYTES] = Blake2s256::digest(login_session_token).into(); - let xsrf_secret = thread_rng().gen::<[u8; LOGIN_SESSION_XSRF_SECRET_BYTES]>(); + let xsrf_secret = randgen.gen::<[u8; LOGIN_SESSION_XSRF_SECRET_BYTES]>(); // store session in the database store @@ -417,20 +418,16 @@ pub async fn post_login( .await .context("failed to store session in database")?; - let expiry_date = chrono::Utc::now() + chrono::Duration::days(500); - let expiry_date_rfc1123 = expiry_date.format("%a, %d %b %Y %H:%M:%S GMT"); - Ok(( - [( - "Set-Cookie", - HeaderValue::from_str(&format!( - "__Host-LoginSession={}; Path=/; HttpOnly; SameSite=Strict; Secure; Expires={}", - login_session_token_b64, expiry_date_rfc1123 - )) - .expect("no reason we should fail to make a cookie"), - )], - make_post_login_redirect(query.then), - ) - .into_response()) + cookies.add( + Cookie::build("__Host-LoginSession", login_session_token_b64.clone()) + .path("/") + .http_only(true) + .secure(true) + .same_site(tower_cookies::cookie::SameSite::Strict) + .max_age(time::Duration::days(500)) + .finish(), + ); + Ok(make_post_login_redirect(query.then)) } /// Make a redirect for once the user has logged in. diff --git a/src/web/oauth_openid/authorisation.rs b/src/web/oauth_openid/authorisation.rs index b3a011e..b16ff7e 100644 --- a/src/web/oauth_openid/authorisation.rs +++ b/src/web/oauth_openid/authorisation.rs @@ -9,7 +9,6 @@ use axum::{ Extension, Form, }; -use chrono::Utc; use eyre::{Context, ContextCompat}; use serde::{Deserialize, Serialize}; @@ -17,6 +16,7 @@ use tracing::{error, warn}; use crate::{ config::{Configuration, OidcClientConfiguration}, + utils::{Clock, RandGen}, web::{ login::LoginSession, make_login_redirect, @@ -69,6 +69,8 @@ pub async fn oidc_authorisation( login_session: Option, Extension(config): Extension>, Extension(code_store): Extension, + Extension(clock): Extension, + Extension(mut randgen): Extension, OriginalUri(uri): OriginalUri, ) -> Response { let Query(query) = match query { @@ -109,7 +111,7 @@ pub async fn oidc_authorisation( // If the application requires consent, then we should ask for that. if !client_config.skip_consent { - return show_consent_page(login_session, client_config, &config).await; + return show_consent_page(login_session, client_config, Extension(clock), &config).await; } // No consent needed: process the authorisation. @@ -120,6 +122,8 @@ pub async fn oidc_authorisation( client_id, client_config, &config, + &mut randgen, + &clock, &code_store, ) .await @@ -133,11 +137,14 @@ pub struct PostConsentForm { } /// `POST /oidc/auth` +#[allow(clippy::too_many_arguments)] pub async fn post_oidc_authorisation_consent( Query(query): Query, login_session: Option, Extension(config): Extension>, Extension(code_store): Extension, + Extension(clock): Extension, + Extension(mut randgen): Extension, OriginalUri(uri): OriginalUri, Form(form): Form, ) -> Response { @@ -152,11 +159,11 @@ pub async fn post_oidc_authorisation_consent( }; if login_session - .validate_xsrf_token(&form.xsrf, Utc::now()) + .validate_xsrf_token(&form.xsrf, clock.now_utc()) .is_err() { // XSRF token is not valid, so show the consent form again... - return show_consent_page(login_session, client_config, &config).await; + return show_consent_page(login_session, client_config, Extension(clock), &config).await; } match form.action.as_str() { @@ -167,6 +174,8 @@ pub async fn post_oidc_authorisation_consent( client_id, client_config, &config, + &mut randgen, + &clock, &code_store, ) .await @@ -233,10 +242,11 @@ fn validate_authorisation_basics<'a>( async fn show_consent_page( login_session: LoginSession, client_config: &OidcClientConfiguration, + Extension(clock): Extension, _config: &Configuration, ) -> Response { let xsrf_token = login_session - .generate_xsrf_token(Utc::now()) + .generate_xsrf_token(clock.now_utc()) .expect("must be able to create a XSRF token"); Html(format!( "hi {}, consent to {}?
", @@ -253,12 +263,15 @@ async fn show_consent_page( /// Preconditions: /// - any required consent from the user has now been obtained /// - query.request_uri has been validated as a safe redirect URI +#[allow(clippy::too_many_arguments)] async fn process_authorisation( query: AuthorisationQuery, login_session: LoginSession, client_id: String, _client_config: &OidcClientConfiguration, config: &Configuration, + randgen: &mut RandGen, + clock: &Clock, code_store: &VolatileCodeStore, ) -> Response { assert_eq!( @@ -271,7 +284,7 @@ async fn process_authorisation( // Generate a 192-bit random code, which fits into exactly 32 base64 characters. // This is an arbitrary choice left to us but I feel a 192-bit value is sufficiently random. - let code = AuthCode::generate_new_random(); + let code = AuthCode::generate_new_random(randgen); let code_base64url = code.to_string(); // Write down the code and other details in-memory with 10 minute expiry... @@ -289,7 +302,7 @@ async fn process_authorisation( user_id: login_session.user_id, user_login_session_id: login_session.login_session_id, }, - 0, + clock.now_timestamp() + 600, ); #[derive(Serialize)] diff --git a/src/web/oauth_openid/ext_codes.rs b/src/web/oauth_openid/ext_codes.rs index a92a0d7..4c626b7 100644 --- a/src/web/oauth_openid/ext_codes.rs +++ b/src/web/oauth_openid/ext_codes.rs @@ -2,10 +2,10 @@ use std::{ collections::{BTreeSet, HashMap}, + fmt::Debug, fmt::Display, str::FromStr, sync::{Arc, Mutex}, - time::{Duration, SystemTime, UNIX_EPOCH}, }; use base64::{display::Base64Display, prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -13,10 +13,12 @@ use rand::Rng; use sqlx::types::Uuid; use tokio::sync::Notify; +use crate::utils::{Clock, RandGen}; + /// Display shows the auth code as base64 (URL-safe non-padded). /// FromStr/parse parses the same format. #[derive(Clone, Hash, PartialEq, Eq, Ord, PartialOrd)] -pub struct AuthCode([u8; 24]); +pub struct AuthCode(pub [u8; 24]); /// Access token pub type AccessToken = [u8; 32]; @@ -37,6 +39,12 @@ impl Display for AuthCode { } } +impl Debug for AuthCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self}") + } +} + impl FromStr for AuthCode { type Err = &'static str; @@ -54,14 +62,15 @@ impl FromStr for AuthCode { impl AuthCode { /// Generate a new authorisation code using the thread's RNG - pub fn generate_new_random() -> Self { - Self(rand::thread_rng().gen::<[u8; 24]>()) + pub fn generate_new_random(randgen: &mut RandGen) -> Self { + Self(randgen.gen::<[u8; 24]>()) } } /// Binding between an authorisation code (ready to be redeemed) /// and both the user that authenticated to produce it /// as well as the OpenID Connect client that it is for. +#[derive(PartialEq, Eq, Debug)] pub struct AuthCodeBinding { /// ID of the OpenID Connect client pub client_id: String, @@ -104,7 +113,7 @@ struct VolatileCodeStoreInner { pub conflictable_codes: HashMap, /// Time when codes will expire - pub expire_codes_at: BTreeSet<(u64, AuthCode)>, + pub expire_codes_at: BTreeSet<(i64, AuthCode)>, } impl VolatileCodeStoreInner { @@ -140,7 +149,7 @@ impl VolatileCodeStoreInner { &mut self, auth_code: AuthCode, auth_code_binding: AuthCodeBinding, - expires_at: u64, + expires_at: i64, ) { self.redeemable_codes .insert(auth_code.clone(), auth_code_binding); @@ -148,13 +157,15 @@ impl VolatileCodeStoreInner { } /// Removes all expired auth codes and returns the time of the earliest next expiry, if present. - pub(self) fn handle_expiry(&mut self, now: u64) -> Option { + pub(self) fn handle_expiry(&mut self, now: i64) -> Option { loop { let (ts, _auth_code) = self.expire_codes_at.first()?; // Remove if expired if *ts <= now { - self.expire_codes_at.pop_first(); + let (_, auth_code) = self.expire_codes_at.pop_first().unwrap(); + self.redeemable_codes.remove(&auth_code); + self.conflictable_codes.remove(&auth_code); continue; } @@ -172,15 +183,16 @@ pub struct VolatileCodeStore { inner: Arc>, } -impl Default for VolatileCodeStore { - fn default() -> Self { +impl VolatileCodeStore { + /// Create a new instance. + pub fn new(clock: Clock) -> Self { let poke = Arc::new(Notify::new()); let inner: Arc> = Default::default(); { let poke = poke.clone(); let inner = inner.clone(); - tokio::spawn(Self::expirer(inner, poke)); + tokio::spawn(Self::expirer(inner, poke, clock)); } VolatileCodeStore { inner, poke } @@ -188,19 +200,15 @@ impl Default for VolatileCodeStore { } impl VolatileCodeStore { - async fn expirer(inner: Arc>, poke: Arc) { - let mut next_expiry: Option = None; + async fn expirer(inner: Arc>, poke: Arc, clock: Clock) { + let mut next_expiry: Option = None; loop { match next_expiry { Some(next_expiry) => { - let sleep_until = UNIX_EPOCH + Duration::from_secs(next_expiry); - let now = SystemTime::now(); - let sleep_for = sleep_until - .duration_since(now) - .unwrap_or(Duration::from_secs(60)); + let sleep_future = clock.sleep_until(next_expiry); tokio::select! { _ = poke.notified() => {}, - _ = tokio::time::sleep(sleep_for) => {}, + _ = sleep_future => {}, } } None => { @@ -208,14 +216,9 @@ impl VolatileCodeStore { } } - let now = SystemTime::now(); next_expiry = { let mut inner = inner.lock().unwrap(); - inner.handle_expiry( - now.duration_since(UNIX_EPOCH) - .expect("system clock before unix epoch") - .as_secs(), - ) + inner.handle_expiry(clock.now_timestamp()) }; } } @@ -234,7 +237,7 @@ impl VolatileCodeStore { } /// Add a new redeemable authorisation code. - pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: u64) { + pub fn add_redeemable(&self, auth_code: AuthCode, binding: AuthCodeBinding, expires_at: i64) { let mut inner = self.inner.lock().unwrap(); inner.add_redeemable(auth_code, binding, expires_at); drop(inner); @@ -244,6 +247,7 @@ impl VolatileCodeStore { } /// Possible outcomes of an attempt to redeem an authorisation code. +#[derive(PartialEq, Eq, Debug)] pub enum CodeRedemption { /// That auth code was not active Invalid, @@ -264,3 +268,212 @@ pub enum CodeRedemption { refresh_token_to_invalidate: RefreshTokenHash, }, } + +#[cfg(test)] +mod test { + use std::{str::FromStr, time::Duration}; + + use assert_matches2::assert_matches; + use rstest::{fixture, rstest}; + use sqlx::types::Uuid; + + use crate::{utils::Clock, web::oauth_openid::ext_codes::CodeRedemption}; + + use super::{AuthCode, AuthCodeBinding, VolatileCodeStore, VolatileCodeStoreInner}; + + const VALID_CODE: AuthCode = AuthCode([21; 24]); + + const USER_UUID: Uuid = Uuid::nil(); + const USER_LOGIN_SESSION_ID: i32 = 1347; + + #[fixture] + fn code_store() -> VolatileCodeStoreInner { + let mut vcs = VolatileCodeStoreInner::default(); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + 128, + ); + vcs + } + + #[rstest] + fn test_redeem_nonexistent_code(mut code_store: VolatileCodeStoreInner) { + let ac = AuthCode([0; 24]); + + let result = code_store.redeem(&ac, [42; 32], [43; 32]); + + assert_eq!(result, CodeRedemption::Invalid); + } + + #[rstest] + fn test_redeem_real_code(mut code_store: VolatileCodeStoreInner) { + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + + assert_matches!(result, CodeRedemption::Valid { binding }); + + assert_eq!(binding.client_id, "client_id"); + } + + #[rstest] + fn test_redeem_code_twice(mut code_store: VolatileCodeStoreInner) { + // redeem a first time + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + + // redeem a second time + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!( + result, + CodeRedemption::Conflicted { + access_token_to_invalidate, + refresh_token_to_invalidate + } + ); + + assert_eq!(access_token_to_invalidate, [42; 32]); + assert_eq!(refresh_token_to_invalidate, [43; 32]); + } + + #[rstest] + fn test_expire_and_redeem_not_expired_yet(mut code_store: VolatileCodeStoreInner) { + code_store.handle_expiry(127); + + // The code shouldn't expire yet. + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + } + + #[rstest] + fn test_expire_and_redeem_expired_now(mut code_store: VolatileCodeStoreInner) { + code_store.handle_expiry(128); + + // The code should have expired + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_eq!(result, CodeRedemption::Invalid); + } + + #[rstest] + fn test_expire_and_redeem_conflict(mut code_store: VolatileCodeStoreInner) { + // redeem a first time + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_matches!(result, CodeRedemption::Valid { .. }); + + code_store.handle_expiry(127); + + // redeem a second time: the conflict should not have expired yet. + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!(result, CodeRedemption::Conflicted { .. }); + + // redeem a third time to show the conflict doesn't get removed. + let result = code_store.redeem(&VALID_CODE, [1; 32], [2; 32]); + assert_matches!(result, CodeRedemption::Conflicted { .. }); + + code_store.handle_expiry(128); + + // The code and its conflict entry should have expired + let result = code_store.redeem(&VALID_CODE, [42; 32], [43; 32]); + assert_eq!(result, CodeRedemption::Invalid); + } + + #[test] + fn test_authcode_display() { + assert_eq!(VALID_CODE.to_string(), "FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV"); + assert_eq!( + format!("{VALID_CODE:?}"), + "FRUVFRUVFRUVFRUVFRUVFRUVFRUVFRUV" + ); + } + + #[test] + fn test_authcode_from_string() { + assert_eq!( + AuthCode::from_str(&VALID_CODE.to_string()).unwrap(), + VALID_CODE + ); + } + + #[test] + fn test_authcode_from_string_wrong_size() { + // Not a valid base64 string + assert_eq!(AuthCode::from_str("a").unwrap_err(), "wrong size"); + + // Not 24 bytes when decoded + assert_eq!(AuthCode::from_str("abcd").unwrap_err(), "wrong size"); + } + + /// Basic test for the full-fat VolatileCodeStore, not just its inner logic + /// We can't easily cover everything here but may as well test the basics. + #[tokio::test] + async fn test_fullfat_store_basic() { + let clock = Clock::new_test(); + let vcs = VolatileCodeStore::new(clock.clone()); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + i64::MAX, + ); + + assert_matches!( + vcs.redeem(&VALID_CODE, [1; 32], [2; 32]), + CodeRedemption::Valid { .. } + ); + + vcs.add_redeemable( + VALID_CODE.clone(), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + // Expiry in the past: this should be auto-expired when poked and + // given a moment + 1, + ); + clock.set_time(2); + + vcs.add_redeemable( + AuthCode([0; 24]), + AuthCodeBinding { + client_id: "client_id".to_owned(), + redirect_uri: "https://client/redirect".to_owned(), + nonce: None, + code_challenge_method: "method".to_owned(), + code_challenge: "challenge".to_owned(), + user_id: USER_UUID, + user_login_session_id: USER_LOGIN_SESSION_ID, + }, + i64::MAX, + ); + + // Give a short time for the expiry to take place. + tokio::time::sleep(Duration::from_millis(2)).await; + + assert_matches!( + vcs.redeem(&VALID_CODE, [1; 32], [2; 32]), + CodeRedemption::Invalid + ); + } +} diff --git a/src/web/oauth_openid/token.rs b/src/web/oauth_openid/token.rs index 7a00e57..f0017c9 100644 --- a/src/web/oauth_openid/token.rs +++ b/src/web/oauth_openid/token.rs @@ -1,10 +1,6 @@ //! `/oidc/token` -use std::{ - str::FromStr, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::{str::FromStr, sync::Arc}; use axum::{ extract::rejection::FormRejection, @@ -15,13 +11,13 @@ use axum::{ }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use blake2::Blake2s256; -use chrono::{Duration, Utc}; +use chrono::Duration; use eyre::{bail, Context}; use josekit::{ jws::{alg::rsassa::RsassaJwsAlgorithm::Rs256, JwsHeader}, jwt::JwtPayload, }; -use rand::{thread_rng, Rng}; +use rand::Rng; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use subtle::ConstantTimeEq; @@ -30,6 +26,7 @@ use tracing::{debug, error}; use crate::{ config::{Configuration, SecretConfig}, store::IdCoopStore, + utils::{Clock, RandGen}, }; use super::ext_codes::{ @@ -53,12 +50,15 @@ pub struct TokenFormParams { /// OpenID Connect clients call this to exchange an authorisation code they received for an access token. /// /// TODO auth_header can be one alternative auth method +#[allow(clippy::too_many_arguments)] pub async fn oidc_token( basic_auth: Option>>, Extension(config): Extension>, Extension(secrets): Extension>, Extension(store): Extension>, Extension(code_store): Extension, + Extension(mut randgen): Extension, + Extension(clock): Extension, form: Result, FormRejection>, ) -> impl IntoResponse { let form = match form { @@ -110,8 +110,9 @@ pub async fn oidc_token( Json(TokenError { code: TokenErrorCode::InvalidClient, description: "That `client_id` is not recognised here.".to_string(), - }) - ).into_response(); + }), + ) + .into_response(); }; if !bool::from( @@ -181,10 +182,10 @@ pub async fn oidc_token( // Create an access token but don't actually issue it yet: // This lets us store the hash of the access token against the redemption of the auth code, // so double redemptions can invalidate the access token appropriately. - let access_token = thread_rng().gen::(); + let access_token = randgen.gen::(); let access_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(access_token); let access_token_hash: AccessTokenHash = Blake2s256::digest(access_token).into(); - let refresh_token = thread_rng().gen::(); + let refresh_token = randgen.gen::(); let refresh_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(refresh_token); let refresh_token_hash: RefreshTokenHash = Blake2s256::digest(refresh_token).into(); @@ -256,7 +257,9 @@ pub async fn oidc_token( } // 2. Check the code challenge - let Some(computed_code_challenge) = compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) else { + let Some(computed_code_challenge) = + compute_code_challenge(&binding.code_challenge_method, &auth_code_verifier) + else { return ( StatusCode::BAD_REQUEST, Json(TokenError { @@ -287,28 +290,29 @@ pub async fn oidc_token( let user_login_session_id = binding.user_login_session_id; // Issue access token for a new session + let clock2 = clock.clone(); match store .txn(move |mut txn| { Box::pin(async move { let Some(session_id) = txn .create_application_session(user_id, &application_id, user_login_session_id) .await - .context("create_application_session")? else { - return Ok(Err( - ( - StatusCode::BAD_REQUEST, - Json(TokenError { - code: TokenErrorCode::InvalidGrant, - description: "Auth code has expired or was not valid.".to_owned(), - }), - ).into_response() - )); + .context("create_application_session")? + else { + return Ok(Err(( + StatusCode::BAD_REQUEST, + Json(TokenError { + code: TokenErrorCode::InvalidGrant, + description: "Auth code has expired or was not valid.".to_owned(), + }), + ) + .into_response())); }; txn.issue_access_token( &access_token_hash, session_id, // TODO(expiry) Support custom expiry, not 100 years - Utc::now() + Duration::days(365 * 100), + clock2.now_utc() + Duration::days(365 * 100), ) .await .context("issue_access_token")?; @@ -316,7 +320,7 @@ pub async fn oidc_token( &refresh_token_hash, session_id, // TODO(expiry) Support custom expiry, not 100 years - Utc::now() + Duration::days(365 * 100), + clock2.now_utc() + Duration::days(365 * 100), ) .await .context("issue_refresh_token")?; @@ -344,10 +348,7 @@ pub async fn oidc_token( } } - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Before unix epoch?") - .as_secs(); + let now = clock.now_timestamp(); // TODO(expiry) Support custom expiry times (not just 100 years) let exp = now + 100 * 365 * 86400; let sub = binding.user_id.hyphenated().to_string(); @@ -385,7 +386,9 @@ pub async fn oidc_token( } fn make_id_token(id_token: IdToken, secrets: &SecretConfig) -> eyre::Result { - let Ok(serde_json::Value::Object(map)) = serde_json::to_value(id_token).context("failed to serialise ID Token content") else { + let Ok(serde_json::Value::Object(map)) = + serde_json::to_value(id_token).context("failed to serialise ID Token content") + else { bail!("ID Token not a map"); }; @@ -447,9 +450,9 @@ pub struct IdToken { /// Implementers MAY provide for some small leeway, usually no more than a few minutes, to account for clock skew. /// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. /// See RFC 3339 [RFC3339] for details regarding date/times in general and UTC in particular. - pub exp: u64, + pub exp: i64, /// REQUIRED. Time at which the JWT was issued. Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. - pub iat: u64, + pub iat: i64, /// Time when the End-User authentication occurred. /// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. /// When a max_age request is made or when auth_time is requested as an Essential Claim, then this Claim is REQUIRED; otherwise, its inclusion is OPTIONAL. diff --git a/src/web/sessionless_xsrf.rs b/src/web/sessionless_xsrf.rs index c06c4db..2ef3208 100644 --- a/src/web/sessionless_xsrf.rs +++ b/src/web/sessionless_xsrf.rs @@ -7,20 +7,22 @@ //! Even on older browsers, that type of attack is rare, so this 'naïve' scheme is fine for the purpose (rather than having to do anything complicated). use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; -use rand::{thread_rng, Rng}; +use rand::Rng; use subtle::ConstantTimeEq; use time::Duration; use tower_cookies::{Cookie, Cookies}; +use crate::utils::RandGen; + /// Name of the cookie for the sessionless cross-site request forgery prevention cookie. pub const COOKIE_NAME: &str = "__Host-SessionlessXsrf"; /// Gets the Sessionless XSRF token to put into a form request -pub fn get_token(cookies: &Cookies) -> String { +pub fn get_token(cookies: &Cookies, randgen: &mut RandGen) -> String { if let Some(xsrf_cookie) = cookies.get(COOKIE_NAME) { xsrf_cookie.value().to_owned() } else { - let new_token = thread_rng().gen::<[u8; 16]>(); + let new_token = randgen.gen::<[u8; 16]>(); let new_token_b64 = BASE64_URL_SAFE_NO_PAD.encode(new_token); cookies.add( Cookie::build(COOKIE_NAME, new_token_b64.clone()) @@ -37,7 +39,9 @@ pub fn get_token(cookies: &Cookies) -> String { /// Checks a Sessionless XSRF token obtained from a form request pub fn check_token(cookies: &Cookies, token: &str) -> bool { - let Some(xsrf_token) = cookies.get(COOKIE_NAME) else { return false }; + let Some(xsrf_token) = cookies.get(COOKIE_NAME) else { + return false; + }; bool::from(xsrf_token.value().as_bytes().ct_eq(token.as_bytes())) }