Compare commits
3 commits
fa7dc5a3f9
...
ead1c7ebad
Author | SHA1 | Date | |
---|---|---|---|
ead1c7ebad | |||
32d76be0fd | |||
360f3fcbbf |
5 changed files with 93 additions and 29 deletions
|
@ -124,7 +124,10 @@ impl ValidateLogin for LdapBackend {
|
||||||
event!(Level::TRACE, ?search_results, "Got raw search results");
|
event!(Level::TRACE, ?search_results, "Got raw search results");
|
||||||
|
|
||||||
let search_entry = match search_results.len() {
|
let search_entry = match search_results.len() {
|
||||||
1 => SearchEntry::construct(search_results.into_iter().next().unwrap()),
|
1 => {
|
||||||
|
#[allow(clippy::unwrap_used)] // we just checked the length is 1
|
||||||
|
SearchEntry::construct(search_results.into_iter().next().unwrap())
|
||||||
|
}
|
||||||
0 => {
|
0 => {
|
||||||
event!(Level::WARN, "No matching LDAP user found");
|
event!(Level::WARN, "No matching LDAP user found");
|
||||||
return Err(AuthenticationError::InvalidUserOrPassword);
|
return Err(AuthenticationError::InvalidUserOrPassword);
|
||||||
|
|
|
@ -168,6 +168,7 @@ impl ServerPadlockGenerator {
|
||||||
|
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub fn generate_padlock(&self, server_hash: &ServerHash) -> ServerPadlock {
|
pub fn generate_padlock(&self, server_hash: &ServerHash) -> ServerPadlock {
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
let mut hmac: Hmac<Sha256> = Hmac::new_from_slice(self.secret.0.expose_secret())
|
let mut hmac: Hmac<Sha256> = Hmac::new_from_slice(self.secret.0.expose_secret())
|
||||||
.expect("HMAC should accept key of any length");
|
.expect("HMAC should accept key of any length");
|
||||||
|
|
||||||
|
@ -207,6 +208,7 @@ impl UserServerKeyGenerator {
|
||||||
|
|
||||||
let padlock = self.padlock_generator.generate_padlock(server_hash);
|
let padlock = self.padlock_generator.generate_padlock(server_hash);
|
||||||
|
|
||||||
|
#[allow(clippy::expect_used)]
|
||||||
let timestamp = OffsetDateTime::now_utc()
|
let timestamp = OffsetDateTime::now_utc()
|
||||||
.format(format_description!(
|
.format(format_description!(
|
||||||
"[year repr:last_two][month][day][hour repr:24][minute][second]"
|
"[year repr:last_two][month][day][hour repr:24][minute][second]"
|
||||||
|
|
21
src/db.rs
21
src/db.rs
|
@ -43,24 +43,20 @@ pub struct SqliteDatabase {
|
||||||
|
|
||||||
impl SqliteDatabase {
|
impl SqliteDatabase {
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub async fn open(connection_string: &str) -> Self {
|
pub async fn open(connection_string: &str) -> Result<Self, sqlx::Error> {
|
||||||
let options = SqliteConnectOptions::from_str(connection_string)
|
let options = SqliteConnectOptions::from_str(connection_string)?.create_if_missing(true);
|
||||||
.expect("Invalid database URI")
|
|
||||||
.create_if_missing(true);
|
|
||||||
|
|
||||||
let mut db = Self {
|
let mut db = Self {
|
||||||
conn: SqliteConnection::connect_with(&options)
|
conn: SqliteConnection::connect_with(&options).await?,
|
||||||
.await
|
|
||||||
.expect("Failed to open SQLite database"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
db.init().await;
|
db.init().await?;
|
||||||
|
|
||||||
db
|
Ok(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[instrument]
|
#[instrument]
|
||||||
pub async fn init(&mut self) {
|
async fn init(&mut self) -> Result<(), sqlx::Error> {
|
||||||
query!(
|
query!(
|
||||||
"CREATE TABLE IF NOT EXISTS user_tokens (
|
"CREATE TABLE IF NOT EXISTS user_tokens (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
@ -72,8 +68,9 @@ impl SqliteDatabase {
|
||||||
)"
|
)"
|
||||||
)
|
)
|
||||||
.execute(&mut self.conn)
|
.execute(&mut self.conn)
|
||||||
.await
|
.await?;
|
||||||
.expect("Failed to initialize table user_tokens");
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
62
src/main.rs
62
src/main.rs
|
@ -16,7 +16,12 @@
|
||||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#![warn(clippy::pedantic, clippy::as_conversions)]
|
#![warn(
|
||||||
|
clippy::pedantic,
|
||||||
|
clippy::as_conversions,
|
||||||
|
clippy::unwrap_used, // allow case by case, add comment explaining why panic can't happen
|
||||||
|
clippy::expect_used // allow case by case, expect message should be self-explanatory
|
||||||
|
)]
|
||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
|
@ -25,27 +30,40 @@ mod db;
|
||||||
mod secrets;
|
mod secrets;
|
||||||
mod server;
|
mod server;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::{env, path::PathBuf, sync::Arc};
|
||||||
|
|
||||||
use auth::{
|
use auth::{
|
||||||
AuthenticationBackend, ServerPadlockGenerator, UserAuthenticator, UserServerKeyGenerator,
|
AuthenticationBackend, ServerPadlockGenerator, UserAuthenticator, UserServerKeyGenerator,
|
||||||
};
|
};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use color_eyre::Result;
|
use color_eyre::{eyre::Context, Result};
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use db::{Database, SqliteDatabase};
|
use db::{Database, SqliteDatabase};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::{event, instrument, Level};
|
use tracing::{event, instrument, level_filters::LevelFilter, Level};
|
||||||
use tracing_error::ErrorLayer;
|
use tracing_error::ErrorLayer;
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||||
|
|
||||||
#[instrument]
|
#[instrument]
|
||||||
fn init() -> Result<()> {
|
fn init() -> Result<()> {
|
||||||
|
const FILTER_ENV_VAR: &str = EnvFilter::DEFAULT_ENV;
|
||||||
|
|
||||||
color_eyre::install()?;
|
color_eyre::install()?;
|
||||||
|
|
||||||
let filter_layer = EnvFilter::try_from_default_env()
|
let mut filter_error = None;
|
||||||
.or_else(|_| EnvFilter::try_new("info"))
|
let filter_layer = EnvFilter::builder()
|
||||||
.unwrap();
|
.with_env_var(FILTER_ENV_VAR)
|
||||||
|
.try_from_env()
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
// sure would be nice if the error type was useful
|
||||||
|
if env::var_os(FILTER_ENV_VAR).is_some() {
|
||||||
|
filter_error = Some(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.parse_lossy("")
|
||||||
|
});
|
||||||
let fmt_layer = tracing_subscriber::fmt::layer().with_target(true);
|
let fmt_layer = tracing_subscriber::fmt::layer().with_target(true);
|
||||||
|
|
||||||
tracing_subscriber::registry()
|
tracing_subscriber::registry()
|
||||||
|
@ -54,6 +72,14 @@ fn init() -> Result<()> {
|
||||||
.with(ErrorLayer::default())
|
.with(ErrorLayer::default())
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
|
if let Some(e) = filter_error {
|
||||||
|
event!(
|
||||||
|
Level::WARN,
|
||||||
|
error = %e,
|
||||||
|
r#"Tracing filter env variable `{FILTER_ENV_VAR}` contained invalid data, falling back to "info""#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,17 +104,29 @@ struct Args {
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
init()?;
|
init().context("Failed to initialize tracing")?;
|
||||||
|
|
||||||
let config = load_config(&args.config).await?;
|
let config = load_config(&args.config).await.with_context(|| {
|
||||||
|
if let Ok(path) = PathBuf::from(&args.config).canonicalize() {
|
||||||
|
format!("Failed to load config from {path:?}")
|
||||||
|
} else {
|
||||||
|
format!("Failed to load config from invalid path {}", &args.config)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
let database: Arc<Mutex<Box<dyn Database + Send>>> = Arc::new(Mutex::new(Box::new(
|
let database: Arc<Mutex<Box<dyn Database + Send>>> = Arc::new(Mutex::new(Box::new(
|
||||||
SqliteDatabase::open(&config.database.connection_string).await,
|
SqliteDatabase::open(&config.database.connection_string)
|
||||||
|
.await
|
||||||
|
.context("Failed to open database")?,
|
||||||
)));
|
)));
|
||||||
|
|
||||||
let mut auth_backends = vec![];
|
let mut auth_backends = vec![];
|
||||||
for c in config.auth_backends {
|
for (i, c) in config.auth_backends.into_iter().enumerate() {
|
||||||
auth_backends.push(AuthenticationBackend::new(c).await?);
|
auth_backends.push(
|
||||||
|
AuthenticationBackend::new(c)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to initialize backend {i}"))?,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_authenticator = Arc::new(UserAuthenticator::new(database, auth_backends));
|
let user_authenticator = Arc::new(UserAuthenticator::new(database, auth_backends));
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
use std::{convert::Infallible, fmt::Debug};
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use hex::FromHex;
|
use hex::FromHex;
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
use secrecy::{ExposeSecret, SecretString, SecretVec};
|
use secrecy::{ExposeSecret, SecretString, SecretVec};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct Password(pub SecretString);
|
pub struct Password(pub SecretString);
|
||||||
|
@ -52,8 +54,18 @@ impl From<String> for ServerPadlock {
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ServerHash(pub String);
|
pub struct ServerHash(pub String);
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct PadlockGenerationSecret(pub SecretVec<u8>);
|
pub struct PadlockGenerationSecret(pub SecretVec<u8>);
|
||||||
|
|
||||||
|
impl PadlockGenerationSecret {
|
||||||
|
/// Entirely arbitrary
|
||||||
|
const MIN_LENGTH_BYTES: usize = 32;
|
||||||
|
|
||||||
|
fn get_random_secret() -> Vec<u8> {
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
(0..Self::MIN_LENGTH_BYTES).map(|_| rng.gen()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Debug for PadlockGenerationSecret {
|
impl Debug for PadlockGenerationSecret {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_tuple("PadlockGenerationSecret")
|
f.debug_tuple("PadlockGenerationSecret")
|
||||||
|
@ -67,10 +79,22 @@ impl Clone for PadlockGenerationSecret {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Error)]
|
||||||
|
#[error(
|
||||||
|
"Padlock secret too short, must be at least {} bytes - here's a fresh secret for you: {}",
|
||||||
|
PadlockGenerationSecret::MIN_LENGTH_BYTES,
|
||||||
|
hex::encode(PadlockGenerationSecret::get_random_secret())
|
||||||
|
)]
|
||||||
|
pub struct PadlockSecretTooShort;
|
||||||
impl FromHex for PadlockGenerationSecret {
|
impl FromHex for PadlockGenerationSecret {
|
||||||
type Error = Infallible;
|
type Error = PadlockSecretTooShort;
|
||||||
|
|
||||||
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
|
fn from_hex<T: AsRef<[u8]>>(hex: T) -> Result<Self, Self::Error> {
|
||||||
Ok(Self(hex.as_ref().to_vec().into()))
|
let hex = hex.as_ref();
|
||||||
|
if hex.len() < Self::MIN_LENGTH_BYTES {
|
||||||
|
Err(PadlockSecretTooShort)
|
||||||
|
} else {
|
||||||
|
Ok(Self(hex.to_vec().into()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue