Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ aws-core = [
# Anything that requires Protocol Buffers.
protobuf-build = ["dep:tonic-build", "dep:prost-build"]

gcp = ["dep:base64", "dep:google-cloud-auth"]
gcp = ["dep:arc-swap", "dep:base64", "dep:google-cloud-auth"]

# Enrichment Tables
enrichment-tables = ["enrichment-tables-geoip", "enrichment-tables-mmdb", "enrichment-tables-memory"]
Expand Down
204 changes: 193 additions & 11 deletions src/gcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::sync::{Arc, LazyLock, Mutex};
use std::time::Duration;

use arc_swap::ArcSwap;
use base64::prelude::{BASE64_URL_SAFE, Engine as _};
use google_cloud_auth::credentials::{AccessTokenCredentials, Builder};
use http::Uri;
Expand All @@ -15,6 +16,8 @@ static ENV_LOCK: Mutex<()> = Mutex::new(());

// See https://cloud.google.com/compute/docs/access/authenticate-workloads#applications
const TOKEN_REFRESH_INTERVAL_SECS: u64 = 3300; // 55 minutes (tokens last 1 hour)
const REBUILD_RETRY_INTERVAL_SECS: u64 = 30;
const MAX_REBUILD_ATTEMPTS: u32 = 3;

pub const PUBSUB_URL: &str = "https://pubsub.googleapis.com";

Expand Down Expand Up @@ -122,13 +125,65 @@ impl GcpAuthConfig {
}
}

#[derive(Clone)]
pub struct CredentialsState {
inner: Arc<ArcSwap<AccessTokenCredentials>>,
path: Option<String>,
scopes: Vec<String>,
}

impl CredentialsState {
fn current_creds(&self) -> Arc<AccessTokenCredentials> {
self.inner.load_full()
}

fn swap_creds(&self, new_creds: Arc<AccessTokenCredentials>) {
self.inner.store(new_creds);
}

async fn rebuild(&self) -> crate::Result<Arc<AccessTokenCredentials>> {
let scopes_vec = self.scopes.clone();
let path = self.path.clone();

tokio::task::spawn_blocking(move || {
let creds = match &path {
Some(path) => {
let _lock = ENV_LOCK.lock().expect("ENV_LOCK poisoned");
let _guard = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path);
Builder::default()
.with_scopes(scopes_vec)
.build_access_token_credentials()
.context(InvalidCredentialsSnafu)?
}
None => Builder::default()
.with_scopes(scopes_vec)
.build_access_token_credentials()
.context(InvalidCredentialsSnafu)?,
};

Ok(Arc::new(creds))
})
.await
.expect("credential rebuild task panicked")
}
}

#[derive(Clone)]
pub enum GcpAuthenticator {
Credentials(Arc<AccessTokenCredentials>),
Credentials(CredentialsState),
ApiKey(Box<str>),
None,
}

impl std::fmt::Debug for CredentialsState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CredentialsState")
.field("path", &self.path)
.field("scopes", &self.scopes)
.finish_non_exhaustive()
}
}

impl std::fmt::Debug for GcpAuthenticator {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -147,34 +202,42 @@ impl GcpAuthenticator {
async fn from_file(path: &str, scopes: &[&str]) -> crate::Result<Self> {
debug!(message = "Loading GCP credentials from file.", path = ?path);

let scopes_vec: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();

// Serialize access: google-cloud-auth Builder has no with_credentials_file()
// method, so we pass the path via the GOOGLE_APPLICATION_CREDENTIALS env var.
// The mutex prevents concurrent from_file() calls from racing on the env var.
let _lock = ENV_LOCK.lock().expect("ENV_LOCK poisoned");

let credentials = {
let _lock = ENV_LOCK.lock().expect("ENV_LOCK poisoned");
let _guard = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path);
let scopes_vec: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();

Builder::default()
.with_scopes(scopes_vec)
.with_scopes(scopes_vec.clone())
.build_access_token_credentials()
.context(InvalidCredentialsSnafu)?
};

Ok(Self::Credentials(Arc::new(credentials)))
Ok(Self::Credentials(CredentialsState {
inner: Arc::new(ArcSwap::from_pointee(credentials)),
path: Some(path.to_string()),
scopes: scopes_vec,
}))
}

async fn from_adc(scopes: &[&str]) -> crate::Result<Self> {
debug!("Loading GCP credentials using Application Default Credentials (ADC).");

let scopes_vec: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
let credentials = Builder::default()
.with_scopes(scopes_vec)
.with_scopes(scopes_vec.clone())
.build_access_token_credentials()
.context(InvalidCredentialsSnafu)?;

Ok(Self::Credentials(Arc::new(credentials)))
Ok(Self::Credentials(CredentialsState {
inner: Arc::new(ArcSwap::from_pointee(credentials)),
path: None,
scopes: scopes_vec,
}))
}

fn from_api_key(api_key: &str) -> crate::Result<Self> {
Expand All @@ -186,7 +249,8 @@ impl GcpAuthenticator {

pub fn make_token(&self) -> Option<String> {
match self {
Self::Credentials(creds) => {
Self::Credentials(state) => {
let creds = state.current_creds();
let result = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async { creds.access_token().await })
});
Expand Down Expand Up @@ -242,24 +306,32 @@ impl GcpAuthenticator {

async fn token_regenerator(self, sender: watch::Sender<()>) {
match self {
Self::Credentials(creds) => loop {
Self::Credentials(state) => loop {
let deadline = Duration::from_secs(TOKEN_REFRESH_INTERVAL_SECS);
debug!(
deadline = deadline.as_secs(),
"Sleeping before refreshing GCP authentication token.",
);
tokio::time::sleep(deadline).await;

let creds = state.current_creds();
match creds.access_token().await {
Ok(_) => {
sender.send_replace(());
debug!("GCP authentication token refreshed.");
}
Err(error) => {
error!(
message = "Failed to refresh GCP authentication token.",
message = "Failed to refresh GCP authentication token, attempting credential rebuild.",
%error
);
if Self::try_rebuild_credentials(&state, &sender).await {
debug!("Credential rebuild succeeded, resuming normal refresh cycle.");
} else {
error!(
message = "All GCP credential rebuild attempts failed. Will retry on next refresh cycle.",
);
}
}
}
},
Expand All @@ -271,6 +343,39 @@ impl GcpAuthenticator {
}
}
}

async fn try_rebuild_credentials(state: &CredentialsState, sender: &watch::Sender<()>) -> bool {
for attempt in 1..=MAX_REBUILD_ATTEMPTS {
match state.rebuild().await {
Ok(new_creds) => match new_creds.access_token().await {
Ok(_) => {
state.swap_creds(new_creds);
sender.send_replace(());
info!(message = "GCP credentials rebuilt successfully.", attempt,);
return true;
}
Err(verify_err) => {
warn!(
message = "Rebuilt GCP credentials failed token verification.",
attempt,
error = %verify_err,
);
}
},
Err(rebuild_err) => {
warn!(
message = "Failed to rebuild GCP credentials.",
attempt,
error = %rebuild_err,
);
}
}
if attempt < MAX_REBUILD_ATTEMPTS {
tokio::time::sleep(Duration::from_secs(REBUILD_RETRY_INTERVAL_SECS)).await;
}
}
false
}
}

/// Temporarily set an environment variable, restoring the original value on drop.
Expand Down Expand Up @@ -366,6 +471,83 @@ mod tests {
let _ = build_auth("").await;
}

#[tokio::test]
async fn credentials_state_swap_updates_current() {
let (creds_file, _dir) = write_fake_external_account_json();
let path = creds_file.to_str().unwrap();

let state = build_credentials_state(path);
let first = state.current_creds();

let new_creds = state.rebuild().await.expect("rebuild should succeed");
state.swap_creds(new_creds);

let second = state.current_creds();
assert!(!Arc::ptr_eq(&first, &second));
}

#[tokio::test]
async fn credentials_state_clone_shares_inner() {
let (creds_file, _dir) = write_fake_external_account_json();
let path = creds_file.to_str().unwrap();

let state = build_credentials_state(path);
let cloned = state.clone();

let new_creds = state.rebuild().await.expect("rebuild should succeed");
state.swap_creds(new_creds);

// The clone should see the swapped credentials
let from_original = state.current_creds();
let from_clone = cloned.current_creds();
assert!(Arc::ptr_eq(&from_original, &from_clone));
}

#[tokio::test]
async fn credentials_state_rebuild_fails_with_nonexistent_path() {
let (creds_file, _dir) = write_fake_external_account_json();
let mut state = build_credentials_state(creds_file.to_str().unwrap());
state.path = Some("/nonexistent/path/credentials.json".into());

assert!(state.rebuild().await.is_err());
}

fn write_fake_external_account_json() -> (std::path::PathBuf, tempfile::TempDir) {
let dir = tempfile::tempdir().expect("failed to create temp dir");
let token_file = dir.path().join("token");
std::fs::write(&token_file, "fake-subject-token").expect("write token");

let creds = serde_json::json!({
"type": "external_account",
"audience": "//iam.googleapis.com/projects/123/locations/global/workloadIdentityPools/pool/providers/provider",
"subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
"token_url": "https://sts.googleapis.com/v1/token",
"credential_source": {
"file": token_file.to_str().unwrap()
}
});

let creds_file = dir.path().join("external_account.json");
std::fs::write(&creds_file, creds.to_string()).expect("write creds");
(creds_file, dir)
}

fn build_credentials_state(path: &str) -> CredentialsState {
let scopes = vec![scopes::CLOUD_PLATFORM.to_string()];
let _lock = ENV_LOCK.lock().expect("ENV_LOCK poisoned");
let _guard = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path);
let creds = Builder::default()
.with_scopes(scopes.clone())
.build_access_token_credentials()
.expect("build_access_token_credentials failed");

CredentialsState {
inner: Arc::new(ArcSwap::from_pointee(creds)),
path: Some(path.to_string()),
scopes,
}
}

fn apply_uri(auth: &GcpAuthenticator, uri: &str) -> String {
let mut uri: Uri = uri.parse().unwrap();
auth.apply_uri(&mut uri);
Expand Down