Compare commits

...

3 commits

Author SHA1 Message Date
Xiretza ead1c7ebad Enforce minimum length for padlock generation secret 2024-02-15 20:58:41 +00:00
Xiretza 32d76be0fd Lint against unwrap/expect 2024-02-15 20:45:34 +00:00
Xiretza 360f3fcbbf Improve error handling 2024-02-15 20:14:07 +00:00
5 changed files with 93 additions and 29 deletions

View file

@ -124,7 +124,10 @@ impl ValidateLogin for LdapBackend {
event!(Level::TRACE, ?search_results, "Got raw search results"); event!(Level::TRACE, ?search_results, "Got raw search results");
let search_entry = match search_results.len() { let search_entry = match search_results.len() {
1 => SearchEntry::construct(search_results.into_iter().next().unwrap()), 1 => {
#[allow(clippy::unwrap_used)] // we just checked the length is 1
SearchEntry::construct(search_results.into_iter().next().unwrap())
}
0 => { 0 => {
event!(Level::WARN, "No matching LDAP user found"); event!(Level::WARN, "No matching LDAP user found");
return Err(AuthenticationError::InvalidUserOrPassword); return Err(AuthenticationError::InvalidUserOrPassword);

View file

@ -168,6 +168,7 @@ impl ServerPadlockGenerator {
#[instrument] #[instrument]
pub fn generate_padlock(&self, server_hash: &ServerHash) -> ServerPadlock { pub fn generate_padlock(&self, server_hash: &ServerHash) -> ServerPadlock {
#[allow(clippy::expect_used)]
let mut hmac: Hmac<Sha256> = Hmac::new_from_slice(self.secret.0.expose_secret()) let mut hmac: Hmac<Sha256> = Hmac::new_from_slice(self.secret.0.expose_secret())
.expect("HMAC should accept key of any length"); .expect("HMAC should accept key of any length");
@ -207,6 +208,7 @@ impl UserServerKeyGenerator {
let padlock = self.padlock_generator.generate_padlock(server_hash); let padlock = self.padlock_generator.generate_padlock(server_hash);
#[allow(clippy::expect_used)]
let timestamp = OffsetDateTime::now_utc() let timestamp = OffsetDateTime::now_utc()
.format(format_description!( .format(format_description!(
"[year repr:last_two][month][day][hour repr:24][minute][second]" "[year repr:last_two][month][day][hour repr:24][minute][second]"

View file

@ -43,24 +43,20 @@ pub struct SqliteDatabase {
impl SqliteDatabase { impl SqliteDatabase {
#[instrument] #[instrument]
pub async fn open(connection_string: &str) -> Self { pub async fn open(connection_string: &str) -> Result<Self, sqlx::Error> {
let options = SqliteConnectOptions::from_str(connection_string) let options = SqliteConnectOptions::from_str(connection_string)?.create_if_missing(true);
.expect("Invalid database URI")
.create_if_missing(true);
let mut db = Self { let mut db = Self {
conn: SqliteConnection::connect_with(&options) conn: SqliteConnection::connect_with(&options).await?,
.await
.expect("Failed to open SQLite database"),
}; };
db.init().await; db.init().await?;
db Ok(db)
} }
#[instrument] #[instrument]
pub async fn init(&mut self) { async fn init(&mut self) -> Result<(), sqlx::Error> {
query!( query!(
"CREATE TABLE IF NOT EXISTS user_tokens ( "CREATE TABLE IF NOT EXISTS user_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -72,8 +68,9 @@ impl SqliteDatabase {
)" )"
) )
.execute(&mut self.conn) .execute(&mut self.conn)
.await .await?;
.expect("Failed to initialize table user_tokens");
Ok(())
} }
} }

View file

@ -16,7 +16,12 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
*/ */
#![warn(clippy::pedantic, clippy::as_conversions)] #![warn(
clippy::pedantic,
clippy::as_conversions,
clippy::unwrap_used, // allow case by case, add comment explaining why panic can't happen
clippy::expect_used // allow case by case, expect message should be self-explanatory
)]
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
mod auth; mod auth;
@ -25,27 +30,40 @@ mod db;
mod secrets; mod secrets;
mod server; mod server;
use std::sync::Arc; use std::{env, path::PathBuf, sync::Arc};
use auth::{ use auth::{
AuthenticationBackend, ServerPadlockGenerator, UserAuthenticator, UserServerKeyGenerator, AuthenticationBackend, ServerPadlockGenerator, UserAuthenticator, UserServerKeyGenerator,
}; };
use clap::Parser; use clap::Parser;
use color_eyre::Result; use color_eyre::{eyre::Context, Result};
use config::Config; use config::Config;
use db::{Database, SqliteDatabase}; use db::{Database, SqliteDatabase};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{event, instrument, Level}; use tracing::{event, instrument, level_filters::LevelFilter, Level};
use tracing_error::ErrorLayer; use tracing_error::ErrorLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
#[instrument] #[instrument]
fn init() -> Result<()> { fn init() -> Result<()> {
const FILTER_ENV_VAR: &str = EnvFilter::DEFAULT_ENV;
color_eyre::install()?; color_eyre::install()?;
let filter_layer = EnvFilter::try_from_default_env() let mut filter_error = None;
.or_else(|_| EnvFilter::try_new("info")) let filter_layer = EnvFilter::builder()
.unwrap(); .with_env_var(FILTER_ENV_VAR)
.try_from_env()
.unwrap_or_else(|e| {
// sure would be nice if the error type was useful
if env::var_os(FILTER_ENV_VAR).is_some() {
filter_error = Some(e);
}
EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.parse_lossy("")
});
let fmt_layer = tracing_subscriber::fmt::layer().with_target(true); let fmt_layer = tracing_subscriber::fmt::layer().with_target(true);
tracing_subscriber::registry() tracing_subscriber::registry()
@ -54,6 +72,14 @@ fn init() -> Result<()> {
.with(ErrorLayer::default()) .with(ErrorLayer::default())
.init(); .init();
if let Some(e) = filter_error {
event!(
Level::WARN,
error = %e,
r#"Tracing filter env variable `{FILTER_ENV_VAR}` contained invalid data, falling back to "info""#
);
}
Ok(()) Ok(())
} }
@ -78,17 +104,29 @@ struct Args {
async fn main() -> Result<()> { async fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
init()?; init().context("Failed to initialize tracing")?;
let config = load_config(&args.config).await?; let config = load_config(&args.config).await.with_context(|| {
if let Ok(path) = PathBuf::from(&args.config).canonicalize() {
format!("Failed to load config from {path:?}")
} else {
format!("Failed to load config from invalid path {}", &args.config)
}
})?;
let database: Arc<Mutex<Box<dyn Database + Send>>> = Arc::new(Mutex::new(Box::new( let database: Arc<Mutex<Box<dyn Database + Send>>> = Arc::new(Mutex::new(Box::new(
SqliteDatabase::open(&config.database.connection_string).await, SqliteDatabase::open(&config.database.connection_string)
.await
.context("Failed to open database")?,
))); )));
let mut auth_backends = vec![]; let mut auth_backends = vec![];
for c in config.auth_backends { for (i, c) in config.auth_backends.into_iter().enumerate() {
auth_backends.push(AuthenticationBackend::new(c).await?); auth_backends.push(
AuthenticationBackend::new(c)
.await
.with_context(|| format!("Failed to initialize backend {i}"))?,
);
} }
let user_authenticator = Arc::new(UserAuthenticator::new(database, auth_backends)); let user_authenticator = Arc::new(UserAuthenticator::new(database, auth_backends));

View file

@ -1,8 +1,10 @@
use std::{convert::Infallible, fmt::Debug}; use std::fmt::Debug;
use hex::FromHex; use hex::FromHex;
use rand::{thread_rng, Rng};
use secrecy::{ExposeSecret, SecretString, SecretVec}; use secrecy::{ExposeSecret, SecretString, SecretVec};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct Password(pub SecretString); pub struct Password(pub SecretString);
@ -52,8 +54,18 @@ impl From<String> for ServerPadlock {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerHash(pub String); pub struct ServerHash(pub String);
#[derive(Deserialize)]
pub struct PadlockGenerationSecret(pub SecretVec<u8>); pub struct PadlockGenerationSecret(pub SecretVec<u8>);
impl PadlockGenerationSecret {
/// Entirely arbitrary
const MIN_LENGTH_BYTES: usize = 32;
fn get_random_secret() -> Vec<u8> {
let mut rng = thread_rng();
(0..Self::MIN_LENGTH_BYTES).map(|_| rng.gen()).collect()
}
}
impl Debug for PadlockGenerationSecret { impl Debug for PadlockGenerationSecret {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("PadlockGenerationSecret") f.debug_tuple("PadlockGenerationSecret")
@ -67,10 +79,22 @@ impl Clone for PadlockGenerationSecret {
} }
} }
#[derive(Debug, Clone, Error)]
#[error(
"Padlock secret too short, must be at least {} bytes - here's a fresh secret for you: {}",
PadlockGenerationSecret::MIN_LENGTH_BYTES,
hex::encode(PadlockGenerationSecret::get_random_secret())
)]
pub struct PadlockSecretTooShort;
impl FromHex for PadlockGenerationSecret { impl FromHex for PadlockGenerationSecret {
type Error = Infallible; type Error = PadlockSecretTooShort;
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> { fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
Ok(Self(hex.as_ref().to_vec().into())) let hex = hex.as_ref();
if hex.len() < Self::MIN_LENGTH_BYTES {
Err(PadlockSecretTooShort)
} else {
Ok(Self(hex.to_vec().into()))
}
} }
} }