diff --git a/Cargo.lock b/Cargo.lock index 707fbd6..a715713 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -576,6 +576,7 @@ dependencies = [ name = "factoriauth" version = "0.1.0" dependencies = [ + "async-trait", "axum", "base64", "clap", diff --git a/Cargo.toml b/Cargo.toml index 97a919f..fa87654 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,3 +39,4 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } rustls = "=0.21.7" rustls-webpki = "=0.101.6" sct = "=0.7.0" +async-trait = "0.1.77" diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 1a83064..befbbb1 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -16,7 +16,7 @@ use tracing::{event, instrument, Level}; use crate::{ config::AuthBackendConfig, - db::{Database, SqliteDatabase, UserTokenEntry}, + db::{Database, UserTokenEntry}, secrets::{ PadlockGenerationSecret, Password, ServerHash, ServerPadlock, UserServerKey, UserToken, }, @@ -69,14 +69,17 @@ impl ValidateLogin for AuthenticationBackend { #[derive(Debug)] pub struct UserAuthenticator { - db: Arc>, + db: Arc>>, backends: Vec, } impl UserAuthenticator { const TOKEN_LEN: usize = 30; - pub fn new(db: Arc>, backends: Vec) -> Self { + pub fn new( + db: Arc>>, + backends: Vec, + ) -> Self { Self { db, backends } } diff --git a/src/db.rs b/src/db.rs index 3816494..8dacbd6 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,6 @@ -use std::str::FromStr; +use std::{fmt::Debug, str::FromStr}; +use axum::async_trait; use secrecy::ExposeSecret; use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, SqliteConnection}; use tracing::instrument; @@ -22,7 +23,8 @@ pub enum UserTokenEntry { ), } -pub trait Database { +#[async_trait] +pub trait Database: Debug { async fn get_token(&mut self, username: &str) -> Result, sqlx::Error>; async fn save_token(&mut self, username: &str, token: &UserToken) -> Result<(), sqlx::Error>; @@ -69,6 +71,7 @@ impl SqliteDatabase { } } +#[async_trait] impl Database for SqliteDatabase { #[instrument] async fn get_token(&mut self, username: &str) -> Result, sqlx::Error> { diff --git a/src/main.rs b/src/main.rs index 517f82c..1eaaad5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,7 @@ use auth::{ use clap::Parser; use color_eyre::Result; use config::Config; -use db::SqliteDatabase; +use db::{Database, SqliteDatabase}; use tokio::sync::Mutex; use tracing::{event, instrument, Level}; use tracing_error::ErrorLayer; @@ -82,9 +82,9 @@ async fn main() -> Result<()> { let config = load_config(&args.config).await?; - let database = Arc::new(Mutex::new( + let database: Arc>> = Arc::new(Mutex::new(Box::new( SqliteDatabase::open(&config.database.connection_string).await, - )); + ))); let mut auth_backends = vec![]; for c in config.auth_backends {