Refactors to configuration handling
This commit is contained in:
parent
59a7218005
commit
9c3dfba3ba
115
src/main.rs
115
src/main.rs
@ -1,5 +1,6 @@
|
|||||||
use std::{net::SocketAddr, process, str::FromStr};
|
use std::{net::SocketAddr, process, str::FromStr};
|
||||||
|
|
||||||
|
use anyhow::{bail};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
@ -22,14 +23,57 @@ use twilight_interactions::command::{CommandInputData, CommandModel, CreateComma
|
|||||||
use twilight_model::{
|
use twilight_model::{
|
||||||
application::interaction::{Interaction, InteractionData, InteractionType},
|
application::interaction::{Interaction, InteractionData, InteractionType},
|
||||||
http::interaction::{InteractionResponse, InteractionResponseType},
|
http::interaction::{InteractionResponse, InteractionResponseType},
|
||||||
id::Id,
|
id::{Id, marker::ApplicationMarker},
|
||||||
};
|
};
|
||||||
mod commands;
|
mod commands;
|
||||||
mod database;
|
mod database;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct Config {
|
||||||
|
discord_client_id: Id<ApplicationMarker>,
|
||||||
|
discord_client_secret: String,
|
||||||
|
discord_pub_key: VerifyingKey,
|
||||||
|
database_url: String,
|
||||||
|
listen_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
fn configure() -> anyhow::Result<Self> {
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
|
|
||||||
|
let pub_key_bytes: Vec<u8> = hex::decode(std::env::var("DISCORD_PUB_KEY")?)?;
|
||||||
|
let pub_key: [u8; 32] = match pub_key_bytes.try_into() {
|
||||||
|
Ok(pk) => pk,
|
||||||
|
Err(_) => bail!("Invalid discord public key"),
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
Ok(Config {
|
||||||
|
discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?,
|
||||||
|
discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?,
|
||||||
|
discord_pub_key: VerifyingKey::from_bytes(&pub_key)?,
|
||||||
|
database_url: std::env::var("DATABASE_URL")?,
|
||||||
|
listen_port:std::env::var("LISTEN_PORT")?.parse()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authorization(&self) -> String {
|
||||||
|
let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD);
|
||||||
|
let auth = format!("{}:{}", self.discord_client_id, self.discord_client_secret);
|
||||||
|
|
||||||
|
engine.encode(auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct AppState {
|
||||||
|
config: Config,
|
||||||
|
pg_pool: PgPool,
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
dotenvy::dotenv().ok();
|
let config = Config::configure()?;
|
||||||
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
let tracer = opentelemetry_jaeger::new_agent_pipeline()
|
||||||
.with_service_name("god_replacement_product")
|
.with_service_name("god_replacement_product")
|
||||||
.install_simple()?;
|
.install_simple()?;
|
||||||
@ -52,20 +96,20 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
)
|
)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
let port = listen_port()?;
|
let port = config.listen_port;
|
||||||
|
|
||||||
let pg_pool = PgPoolOptions::new()
|
let pg_pool = PgPoolOptions::new()
|
||||||
.max_connections(5)
|
.max_connections(5)
|
||||||
.connect(database_url()?.as_str())
|
.connect(&config.database_url)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
sqlx::migrate!().run(&pg_pool).await?;
|
sqlx::migrate!().run(&pg_pool).await?;
|
||||||
|
|
||||||
|
register_commands(config.discord_client_id.to_owned(), config.authorization()).await?;
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/api/discord/interactions/", post(post_interaction))
|
.route("/api/discord/interactions/", post(post_interaction))
|
||||||
.with_state(pg_pool)
|
.with_state(AppState { config, pg_pool })
|
||||||
.layer(TraceLayer::new_for_http());
|
.layer(TraceLayer::new_for_http());
|
||||||
register_commands().await?;
|
|
||||||
|
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||||
tracing::debug!("listening on {}", addr);
|
tracing::debug!("listening on {}", addr);
|
||||||
@ -111,12 +155,12 @@ async fn shutdown_signal() {
|
|||||||
|
|
||||||
type InteractionResult = Result<(StatusCode, Json<InteractionResponse>), (StatusCode, String)>;
|
type InteractionResult = Result<(StatusCode, Json<InteractionResponse>), (StatusCode, String)>;
|
||||||
|
|
||||||
fn validate_request(headers: HeaderMap, body: String) -> Result<Interaction, (StatusCode, String)> {
|
fn validate_request(headers: HeaderMap, body: String, pub_key: VerifyingKey) -> Result<Interaction, (StatusCode, String)> {
|
||||||
let Ok(interaction): Result<Interaction, _> = serde_json::from_str(&body) else {
|
let Ok(interaction): Result<Interaction, _> = serde_json::from_str(&body) else {
|
||||||
return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
|
return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
|
||||||
};
|
};
|
||||||
let Some(sig) = headers.get("x-signature-ed25519") else {
|
let Some(sig) = headers.get("x-signature-ed25519") else {
|
||||||
return Err((StatusCode::BAD_REQUEST, "requrest did not include signature header".to_string()))
|
return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string()))
|
||||||
};
|
};
|
||||||
let Ok(sig) = hex::decode(sig) else {
|
let Ok(sig) = hex::decode(sig) else {
|
||||||
return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string()))
|
return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string()))
|
||||||
@ -130,7 +174,6 @@ fn validate_request(headers: HeaderMap, body: String) -> Result<Interaction, (St
|
|||||||
let mut signed_buf = signed_buf.as_bytes().to_owned();
|
let mut signed_buf = signed_buf.as_bytes().to_owned();
|
||||||
signed_buf.extend(body.as_bytes());
|
signed_buf.extend(body.as_bytes());
|
||||||
|
|
||||||
let pub_key = discord_pub_key();
|
|
||||||
let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else {
|
let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else {
|
||||||
return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string()))
|
return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string()))
|
||||||
};
|
};
|
||||||
@ -138,12 +181,13 @@ fn validate_request(headers: HeaderMap, body: String) -> Result<Interaction, (St
|
|||||||
Ok(interaction)
|
Ok(interaction)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async fn post_interaction(
|
async fn post_interaction(
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
State(pg_pool): State<PgPool>,
|
State(app_state): State<AppState>,
|
||||||
body: String,
|
body: String,
|
||||||
) -> InteractionResult {
|
) -> InteractionResult {
|
||||||
let interaction = match validate_request(headers, body) {
|
let interaction = match validate_request(headers, body, app_state.config.discord_pub_key) {
|
||||||
Ok(interaction) => interaction,
|
Ok(interaction) => interaction,
|
||||||
Err(error) => return Err(error),
|
Err(error) => return Err(error),
|
||||||
};
|
};
|
||||||
@ -177,7 +221,7 @@ async fn post_interaction(
|
|||||||
interaction.channel_id,
|
interaction.channel_id,
|
||||||
author_id,
|
author_id,
|
||||||
command_data,
|
command_data,
|
||||||
&pg_pool,
|
&app_state.pg_pool,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
@ -193,7 +237,7 @@ async fn post_interaction(
|
|||||||
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME)));
|
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME)));
|
||||||
};
|
};
|
||||||
|
|
||||||
match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await
|
match get_fact(interaction.channel_id, author_id, command_data, &app_state.pg_pool).await
|
||||||
{
|
{
|
||||||
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
||||||
Err(err) => Err(err),
|
Err(err) => Err(err),
|
||||||
@ -216,18 +260,9 @@ fn not_found() -> InteractionResult {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn discord_pub_key_bytes() -> Vec<u8> {
|
async fn register_commands(discord_client_id: Id<ApplicationMarker>, authorization: String) -> anyhow::Result<()> {
|
||||||
hex::decode(std::env::var("DISCORD_PUB_KEY").unwrap()).unwrap()
|
discord_client(authorization)?
|
||||||
}
|
.interaction(discord_client_id)
|
||||||
|
|
||||||
fn discord_pub_key() -> VerifyingKey {
|
|
||||||
let pub_key_bytes: [u8; 32] = discord_pub_key_bytes().try_into().unwrap();
|
|
||||||
VerifyingKey::from_bytes(&pub_key_bytes).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn register_commands() -> anyhow::Result<()> {
|
|
||||||
discord_client()?
|
|
||||||
.interaction(Id::from_str(&discord_client_id()?)?)
|
|
||||||
.set_global_commands(&[
|
.set_global_commands(&[
|
||||||
GetFactCommand::create_command().into(),
|
GetFactCommand::create_command().into(),
|
||||||
SetFactCommand::create_command().into(),
|
SetFactCommand::create_command().into(),
|
||||||
@ -241,15 +276,10 @@ struct ClientCredentialsResponse {
|
|||||||
access_token: String,
|
access_token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn authorization() -> anyhow::Result<String> {
|
|
||||||
let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD);
|
|
||||||
let auth = format!("{}:{}", discord_client_id()?, discord_client_secret()?);
|
|
||||||
Ok(engine.encode(auth))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn client_credentials_grant() -> anyhow::Result<ClientCredentialsResponse> {
|
fn client_credentials_grant(authorization: String) -> anyhow::Result<ClientCredentialsResponse> {
|
||||||
Ok(ureq::post("https://discord.com/api/v10/oauth2/token")
|
Ok(ureq::post("https://discord.com/api/v10/oauth2/token")
|
||||||
.set("Authorization", &format!("Basic {}", authorization()?))
|
.set("Authorization", &format!("Basic {}", authorization))
|
||||||
.send_form(&[
|
.send_form(&[
|
||||||
("grant_type", "client_credentials"),
|
("grant_type", "client_credentials"),
|
||||||
("scope", "applications.commands.update"),
|
("scope", "applications.commands.update"),
|
||||||
@ -258,25 +288,8 @@ fn client_credentials_grant() -> anyhow::Result<ClientCredentialsResponse> {
|
|||||||
.into_json()?)
|
.into_json()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn discord_client_id() -> anyhow::Result<String> {
|
fn discord_client(authorization: String) -> anyhow::Result<Client> {
|
||||||
std::env::var("DISCORD_CLIENT_ID").map_err(Into::into)
|
let token = client_credentials_grant(authorization)?.access_token;
|
||||||
}
|
|
||||||
|
|
||||||
fn discord_client_secret() -> anyhow::Result<String> {
|
|
||||||
std::env::var("DISCORD_CLIENT_SECRET").map_err(Into::into)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn discord_client() -> anyhow::Result<Client> {
|
|
||||||
let token = client_credentials_grant()?.access_token;
|
|
||||||
Ok(Client::new(format!("Bearer {token}")))
|
Ok(Client::new(format!("Bearer {token}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn database_url() -> anyhow::Result<String> {
|
|
||||||
std::env::var("DATABASE_URL").map_err(Into::into)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn listen_port() -> anyhow::Result<u16> {
|
|
||||||
std::env::var("LISTEN_PORT")
|
|
||||||
.map_err(Into::into)
|
|
||||||
.and_then(|v| v.parse::<u16>().map_err(Into::into))
|
|
||||||
}
|
|
||||||
|
Reference in New Issue
Block a user