Skip to content

Commit

Permalink
refactor: Improve AuthenticationStorage (#1026)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelzw authored Jan 17, 2025
1 parent d13a186 commit 7591044
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 165 deletions.
7 changes: 2 additions & 5 deletions crates/rattler-bin/src/commands/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use rattler_conda_types::{
Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, ParseStrictness, Platform,
PrefixRecord, RepoDataRecord, Version,
};
use rattler_networking::{AuthenticationMiddleware, AuthenticationStorage};
use rattler_networking::AuthenticationMiddleware;
use rattler_repodata_gateway::{Gateway, RepoData, SourceConfig};
use rattler_solve::{
libsolv_c::{self},
Expand Down Expand Up @@ -147,11 +147,8 @@ pub async fn create(opt: Opt) -> anyhow::Result<()> {
.build()
.expect("failed to create client");

let authentication_storage = AuthenticationStorage::default();
let download_client = reqwest_middleware::ClientBuilder::new(download_client)
.with_arc(Arc::new(AuthenticationMiddleware::new(
authentication_storage,
)))
.with_arc(Arc::new(AuthenticationMiddleware::from_env_and_defaults()?))
.with(rattler_networking::OciMiddleware)
.with(rattler_networking::GCSMiddleware)
.build();
Expand Down
12 changes: 10 additions & 2 deletions crates/rattler/src/cli/auth.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module contains CLI common entrypoint for authentication.
use clap::Parser;
use rattler_networking::{Authentication, AuthenticationStorage};
use rattler_networking::{
authentication_storage::backends::file::FileStorageError, Authentication, AuthenticationStorage,
};
use thiserror;

/// Command line arguments that contain authentication data
Expand Down Expand Up @@ -71,6 +73,11 @@ pub enum AuthenticationCLIError {
#[error("Authentication with anaconda.org requires a conda token. Use `--conda-token` to provide one")]
AnacondaOrgBadMethod,

/// Wrapper for errors that are generated from the underlying storage system
/// (keyring or file system)
#[error("Failed to initialize the authentication storage system")]
InitializeStorageError(#[source] FileStorageError),

/// Wrapper for errors that are generated from the underlying storage system
/// (keyring or file system)
#[error("Failed to interact with the authentication storage system")]
Expand Down Expand Up @@ -141,7 +148,8 @@ fn logout(args: LogoutArgs, storage: AuthenticationStorage) -> Result<(), Authen

/// CLI entrypoint for authentication
pub async fn execute(args: Args) -> Result<(), AuthenticationCLIError> {
let storage = AuthenticationStorage::default();
let storage = AuthenticationStorage::from_env_and_defaults()
.map_err(AuthenticationCLIError::InitializeStorageError)?;

match args.subcommand {
Subcommand::Login(args) => login(args, storage),
Expand Down
2 changes: 1 addition & 1 deletion crates/rattler_networking/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ gcs = ["google-cloud-auth", "google-cloud-token"]

[dependencies]
anyhow = { workspace = true }
async-fd-lock = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
chrono = { workspace = true }
dirs = { workspace = true }
fslock = { workspace = true }
google-cloud-auth = { workspace = true, optional = true }
google-cloud-token = { workspace = true, optional = true }
http = { workspace = true }
Expand Down
38 changes: 24 additions & 14 deletions crates/rattler_networking/src/authentication_middleware.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! `reqwest` middleware that authenticates requests with data from the `AuthenticationStorage`
use crate::authentication_storage::backends::file::FileStorageError;
use crate::{Authentication, AuthenticationStorage};
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
Expand All @@ -10,7 +11,7 @@ use std::sync::OnceLock;
use url::Url;

/// `reqwest` middleware to authenticate requests
#[derive(Clone, Default)]
#[derive(Clone)]
pub struct AuthenticationMiddleware {
auth_storage: AuthenticationStorage,
}
Expand Down Expand Up @@ -49,10 +50,17 @@ impl Middleware for AuthenticationMiddleware {

impl AuthenticationMiddleware {
/// Create a new authentication middleware with the given authentication storage
pub fn new(auth_storage: AuthenticationStorage) -> Self {
pub fn from_auth_storage(auth_storage: AuthenticationStorage) -> Self {
Self { auth_storage }
}

/// Create a new authentication middleware with the default authentication storage
pub fn from_env_and_defaults() -> Result<Self, FileStorageError> {
Ok(Self {
auth_storage: AuthenticationStorage::from_env_and_defaults()?,
})
}

/// Authenticate the given URL with the given authentication information
fn authenticate_url(url: Url, auth: &Option<Authentication>) -> Url {
if let Some(credentials) = auth {
Expand Down Expand Up @@ -166,7 +174,9 @@ mod tests {
) {
let (captured_tx, captured_rx) = tokio::sync::mpsc::channel(1);
let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::default())
.with_arc(Arc::new(AuthenticationMiddleware::new(storage.clone())))
.with_arc(Arc::new(AuthenticationMiddleware::from_auth_storage(
storage.clone(),
)))
.with_arc(Arc::new(CaptureAbortMiddleware { captured_tx }))
.build();

Expand All @@ -176,8 +186,8 @@ mod tests {
#[test]
fn test_store_fallback() -> anyhow::Result<()> {
let tdir = tempdir()?;
let mut storage = AuthenticationStorage::new();
storage.add_backend(Arc::from(FileStorage::new(
let mut storage = AuthenticationStorage::empty();
storage.add_backend(Arc::from(FileStorage::from_path(
tdir.path().to_path_buf().join("auth.json"),
)?));

Expand All @@ -191,8 +201,8 @@ mod tests {
#[tokio::test]
async fn test_conda_token_storage() -> anyhow::Result<()> {
let tdir = tempdir()?;
let mut storage = AuthenticationStorage::new();
storage.add_backend(Arc::from(FileStorage::new(
let mut storage = AuthenticationStorage::empty();
storage.add_backend(Arc::from(FileStorage::from_path(
tdir.path().to_path_buf().join("auth.json"),
)?));

Expand Down Expand Up @@ -245,8 +255,8 @@ mod tests {
#[tokio::test]
async fn test_bearer_storage() -> anyhow::Result<()> {
let tdir = tempdir()?;
let mut storage = AuthenticationStorage::new();
storage.add_backend(Arc::from(FileStorage::new(
let mut storage = AuthenticationStorage::empty();
storage.add_backend(Arc::from(FileStorage::from_path(
tdir.path().to_path_buf().join("auth.json"),
)?));
let host = "bearer.example.com";
Expand Down Expand Up @@ -305,8 +315,8 @@ mod tests {
#[tokio::test]
async fn test_basic_auth_storage() -> anyhow::Result<()> {
let tdir = tempdir()?;
let mut storage = AuthenticationStorage::new();
storage.add_backend(Arc::from(FileStorage::new(
let mut storage = AuthenticationStorage::empty();
storage.add_backend(Arc::from(FileStorage::from_path(
tdir.path().to_path_buf().join("auth.json"),
)?));
let host = "basic.example.com";
Expand Down Expand Up @@ -383,8 +393,8 @@ mod tests {
("*.com", false),
] {
let tdir = tempdir()?;
let mut storage = AuthenticationStorage::new();
storage.add_backend(Arc::from(FileStorage::new(
let mut storage = AuthenticationStorage::empty();
storage.add_backend(Arc::from(FileStorage::from_path(
tdir.path().to_path_buf().join("auth.json"),
)?));

Expand Down Expand Up @@ -418,7 +428,7 @@ mod tests {
.to_str()
.unwrap(),
),
|| AuthenticationStorage::from_env().unwrap(),
|| AuthenticationStorage::from_env_and_defaults().unwrap(),
);

let host = "test.example.com";
Expand Down
125 changes: 54 additions & 71 deletions crates/rattler_networking/src/authentication_storage/backends/file.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
//! file storage for passwords.
use anyhow::Result;
use fslock::LockFile;
use async_fd_lock::{
blocking::{LockRead, LockWrite},
RwLockWriteGuard,
};
use std::collections::BTreeMap;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use std::{path::PathBuf, sync::Mutex};
use std::path::PathBuf;
use std::sync::{Arc, RwLock};

use crate::authentication_storage::StorageBackend;
use crate::Authentication;

#[derive(Clone, Debug)]
struct FileStorageCache {
cache: BTreeMap<String, Authentication>,
content: BTreeMap<String, Authentication>,
file_exists: bool,
}

Expand All @@ -25,7 +29,7 @@ pub struct FileStorage {
/// The cache of the file storage
/// This is used to avoid reading the file from disk every time
/// a credential is accessed
cache: Arc<Mutex<FileStorageCache>>,
cache: Arc<RwLock<FileStorageCache>>,
}

/// An error that can occur when accessing the file storage
Expand All @@ -36,86 +40,80 @@ pub enum FileStorageError {
IOError(#[from] std::io::Error),

/// Failed to lock the file storage file
#[error("failed to lock file storage file {0}")]
FailedToLock(String, #[source] std::io::Error),
#[error("failed to lock file storage file: {0:?}")]
FailedToLock(async_fd_lock::LockError<std::fs::File>),

/// An error occurred when (de)serializing the credentials
#[error("JSON error: {0}")]
JSONError(#[from] serde_json::Error),
}

/// Lock the file storage file for reading and writing. This will block until the lock is
/// acquired.
fn lock_file_storage(path: &Path, write: bool) -> Result<Option<LockFile>, FileStorageError> {
if !write && !path.exists() {
return Ok(None);
}

std::fs::create_dir_all(path.parent().unwrap())?;
let path = path.with_extension("lock");
let mut lock = fslock::LockFile::open(&path)
.map_err(|e| FileStorageError::FailedToLock(path.to_string_lossy().into_owned(), e))?;

// First try to lock the file without block. If we can't immediately get the lock we block and issue a debug message.
if !lock
.try_lock_with_pid()
.map_err(|e| FileStorageError::FailedToLock(path.to_string_lossy().into_owned(), e))?
{
tracing::debug!("waiting for lock on {}", path.to_string_lossy());
lock.lock_with_pid()
.map_err(|e| FileStorageError::FailedToLock(path.to_string_lossy().into_owned(), e))?;
}

Ok(Some(lock))
}

impl FileStorageCache {
pub fn from_path(path: &Path) -> Result<Self, FileStorageError> {
let file_exists = path.exists();
let cache = if file_exists {
lock_file_storage(path, false)?;
let file = std::fs::File::open(path)?;
let reader = std::io::BufReader::new(file);
serde_json::from_reader(reader)?
let content = if file_exists {
let read_guard = File::options()
.read(true)
.open(path)?
.lock_read()
.map_err(FileStorageError::FailedToLock)?;
serde_json::from_reader(read_guard)?
} else {
BTreeMap::new()
};

Ok(Self { cache, file_exists })
Ok(Self {
content,
file_exists,
})
}
}

impl FileStorage {
/// Create a new file storage with the given path
pub fn new(path: PathBuf) -> Result<Self, FileStorageError> {
pub fn from_path(path: PathBuf) -> Result<Self, FileStorageError> {
// read the JSON file if it exists, and store it in the cache
let cache = Arc::new(Mutex::new(FileStorageCache::from_path(&path)?));
let cache = Arc::new(RwLock::new(FileStorageCache::from_path(&path)?));

Ok(Self { path, cache })
}

/// Read the JSON file and deserialize it into a `BTreeMap`, or return an empty `BTreeMap` if the
/// Create a new file storage with the default path
pub fn new() -> Result<Self, FileStorageError> {
let path = dirs::home_dir()
.unwrap()
.join(".rattler")
.join("credentials.json");
Self::from_path(path)
}

/// Updates the cache by reading the JSON file and deserializing it into a `BTreeMap`, or return an empty `BTreeMap` if the
/// file does not exist
fn read_json(&self) -> Result<BTreeMap<String, Authentication>, FileStorageError> {
let new_cache = FileStorageCache::from_path(&self.path)?;
let mut cache = self.cache.lock().unwrap();
cache.cache = new_cache.cache;
let mut cache = self.cache.write().unwrap();
cache.content = new_cache.content;
cache.file_exists = new_cache.file_exists;

Ok(cache.cache.clone())
Ok(cache.content.clone())
}

/// Serialize the given `BTreeMap` and write it to the JSON file
fn write_json(&self, dict: &BTreeMap<String, Authentication>) -> Result<(), FileStorageError> {
let _lock = lock_file_storage(&self.path, true)?;

let file = std::fs::File::create(&self.path)?;
let writer = std::io::BufWriter::new(file);
serde_json::to_writer(writer, dict)?;
let write_guard: std::result::Result<
RwLockWriteGuard<File>,
async_fd_lock::LockError<File>,
> = File::options()
.create(true)
.write(true)
.truncate(true)
.open(&self.path)?
.lock_write();
let write_guard = write_guard.map_err(FileStorageError::FailedToLock)?;
serde_json::to_writer(write_guard, dict)?;

// Store the new data in the cache
let mut cache = self.cache.lock().unwrap();
cache.cache = dict.clone();
let mut cache = self.cache.write().unwrap();
cache.content = dict.clone();
cache.file_exists = true;

Ok(())
Expand All @@ -130,8 +128,8 @@ impl StorageBackend for FileStorage {
}

fn get(&self, host: &str) -> Result<Option<crate::Authentication>> {
let cache = self.cache.lock().unwrap();
Ok(cache.cache.get(host).cloned())
let cache = self.cache.read().unwrap();
Ok(cache.content.get(host).cloned())
}

fn delete(&self, host: &str) -> Result<()> {
Expand All @@ -144,21 +142,6 @@ impl StorageBackend for FileStorage {
}
}

impl Default for FileStorage {
fn default() -> Self {
let mut path = dirs::home_dir().unwrap();
path.push(".rattler");
path.push("credentials.json");
Self::new(path.clone()).unwrap_or(Self {
path,
cache: Arc::new(Mutex::new(FileStorageCache {
cache: BTreeMap::new(),
file_exists: false,
})),
})
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -171,7 +154,7 @@ mod tests {
let file = tempdir().unwrap();
let path = file.path().join("test.json");

let storage = FileStorage::new(path.clone()).unwrap();
let storage = FileStorage::from_path(path.clone()).unwrap();

assert_eq!(storage.get("test").unwrap(), None);

Expand Down Expand Up @@ -207,6 +190,6 @@ mod tests {
let mut file = std::fs::File::create(&path).unwrap();
file.write_all(b"invalid json").unwrap();

assert!(FileStorage::new(path.clone()).is_err());
assert!(FileStorage::from_path(path.clone()).is_err());
}
}
Loading

0 comments on commit 7591044

Please sign in to comment.