From 9c3dfba3ba3833dc4da9997c8936af2d8d8cfcf5 Mon Sep 17 00:00:00 2001 From: ModZero Date: Tue, 28 Mar 2023 18:44:13 +0200 Subject: [PATCH] Refactors to configuration handling --- src/main.rs | 115 +++++++++++++++++++++++++++++----------------------- 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/src/main.rs b/src/main.rs index 0d2ecc5..bf7dbad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use std::{net::SocketAddr, process, str::FromStr}; +use anyhow::{bail}; use axum::{ extract::State, http::{HeaderMap, StatusCode}, @@ -22,14 +23,57 @@ use twilight_interactions::command::{CommandInputData, CommandModel, CreateComma use twilight_model::{ application::interaction::{Interaction, InteractionData, InteractionType}, http::interaction::{InteractionResponse, InteractionResponseType}, - id::Id, + id::{Id, marker::ApplicationMarker}, }; mod commands; mod database; +#[derive(Clone)] +struct Config { + discord_client_id: Id, + discord_client_secret: String, + discord_pub_key: VerifyingKey, + database_url: String, + listen_port: u16, +} + +impl Config { + fn configure() -> anyhow::Result { + dotenvy::dotenv().ok(); + + let pub_key_bytes: Vec = hex::decode(std::env::var("DISCORD_PUB_KEY")?)?; + let pub_key: [u8; 32] = match pub_key_bytes.try_into() { + Ok(pk) => pk, + Err(_) => bail!("Invalid discord public key"), + }; + + + Ok(Config { + discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?, + discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?, + discord_pub_key: VerifyingKey::from_bytes(&pub_key)?, + database_url: std::env::var("DATABASE_URL")?, + listen_port:std::env::var("LISTEN_PORT")?.parse()?, + }) + } + + fn authorization(&self) -> String { + let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD); + let auth = format!("{}:{}", self.discord_client_id, self.discord_client_secret); + + engine.encode(auth) + } +} + +#[derive(Clone)] +struct AppState { + config: Config, + pg_pool: PgPool, +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - dotenvy::dotenv().ok(); + let config = Config::configure()?; let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_service_name("god_replacement_product") .install_simple()?; @@ -52,20 +96,20 @@ async fn main() -> anyhow::Result<()> { ) .init(); - let port = listen_port()?; + let port = config.listen_port; let pg_pool = PgPoolOptions::new() .max_connections(5) - .connect(database_url()?.as_str()) + .connect(&config.database_url) .await?; sqlx::migrate!().run(&pg_pool).await?; + register_commands(config.discord_client_id.to_owned(), config.authorization()).await?; let app = Router::new() .route("/api/discord/interactions/", post(post_interaction)) - .with_state(pg_pool) + .with_state(AppState { config, pg_pool }) .layer(TraceLayer::new_for_http()); - register_commands().await?; let addr = SocketAddr::from(([127, 0, 0, 1], port)); tracing::debug!("listening on {}", addr); @@ -111,12 +155,12 @@ async fn shutdown_signal() { type InteractionResult = Result<(StatusCode, Json), (StatusCode, String)>; -fn validate_request(headers: HeaderMap, body: String) -> Result { +fn validate_request(headers: HeaderMap, body: String, pub_key: VerifyingKey) -> Result { let Ok(interaction): Result = serde_json::from_str(&body) else { return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string())) }; let Some(sig) = headers.get("x-signature-ed25519") else { - return Err((StatusCode::BAD_REQUEST, "requrest did not include signature header".to_string())) + return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string())) }; let Ok(sig) = hex::decode(sig) else { return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string())) @@ -130,7 +174,6 @@ fn validate_request(headers: HeaderMap, body: String) -> Result Result, + State(app_state): State, body: String, ) -> InteractionResult { - let interaction = match validate_request(headers, body) { + let interaction = match validate_request(headers, body, app_state.config.discord_pub_key) { Ok(interaction) => interaction, Err(error) => return Err(error), }; @@ -177,7 +221,7 @@ async fn post_interaction( interaction.channel_id, author_id, command_data, - &pg_pool, + &app_state.pg_pool, ) .await { @@ -193,7 +237,7 @@ async fn post_interaction( return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME))); }; - match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await + match get_fact(interaction.channel_id, author_id, command_data, &app_state.pg_pool).await { Ok(response) => Ok((StatusCode::OK, Json(response))), Err(err) => Err(err), @@ -216,18 +260,9 @@ fn not_found() -> InteractionResult { )) } -fn discord_pub_key_bytes() -> Vec { - hex::decode(std::env::var("DISCORD_PUB_KEY").unwrap()).unwrap() -} - -fn discord_pub_key() -> VerifyingKey { - let pub_key_bytes: [u8; 32] = discord_pub_key_bytes().try_into().unwrap(); - VerifyingKey::from_bytes(&pub_key_bytes).unwrap() -} - -async fn register_commands() -> anyhow::Result<()> { - discord_client()? - .interaction(Id::from_str(&discord_client_id()?)?) +async fn register_commands(discord_client_id: Id, authorization: String) -> anyhow::Result<()> { + discord_client(authorization)? + .interaction(discord_client_id) .set_global_commands(&[ GetFactCommand::create_command().into(), SetFactCommand::create_command().into(), @@ -241,15 +276,10 @@ struct ClientCredentialsResponse { access_token: String, } -fn authorization() -> anyhow::Result { - let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD); - let auth = format!("{}:{}", discord_client_id()?, discord_client_secret()?); - Ok(engine.encode(auth)) -} -fn client_credentials_grant() -> anyhow::Result { +fn client_credentials_grant(authorization: String) -> anyhow::Result { Ok(ureq::post("https://discord.com/api/v10/oauth2/token") - .set("Authorization", &format!("Basic {}", authorization()?)) + .set("Authorization", &format!("Basic {}", authorization)) .send_form(&[ ("grant_type", "client_credentials"), ("scope", "applications.commands.update"), @@ -258,25 +288,8 @@ fn client_credentials_grant() -> anyhow::Result { .into_json()?) } -fn discord_client_id() -> anyhow::Result { - std::env::var("DISCORD_CLIENT_ID").map_err(Into::into) -} - -fn discord_client_secret() -> anyhow::Result { - std::env::var("DISCORD_CLIENT_SECRET").map_err(Into::into) -} - -fn discord_client() -> anyhow::Result { - let token = client_credentials_grant()?.access_token; +fn discord_client(authorization: String) -> anyhow::Result { + let token = client_credentials_grant(authorization)?.access_token; Ok(Client::new(format!("Bearer {token}"))) } -fn database_url() -> anyhow::Result { - std::env::var("DATABASE_URL").map_err(Into::into) -} - -fn listen_port() -> anyhow::Result { - std::env::var("LISTEN_PORT") - .map_err(Into::into) - .and_then(|v| v.parse::().map_err(Into::into)) -}