diff --git a/Cargo.lock b/Cargo.lock index 164b3ec..fdfef9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -636,6 +636,7 @@ dependencies = [ "dotenvy", "ed25519-dalek", "hex", + "hyper", "opentelemetry", "opentelemetry-jaeger", "serde", @@ -643,6 +644,7 @@ dependencies = [ "sqlx", "time", "tokio", + "tower", "tower-http", "tracing", "tracing-loki", @@ -695,6 +697,16 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "hdrhistogram" +version = "7.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f19b9f54f7c7f55e31401bb647626ce0cf0f67b0004982ce815b3ee72a02aa8" +dependencies = [ + "byteorder", + "num-traits", +] + [[package]] name = "heck" version = "0.4.1" @@ -2135,9 +2147,14 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", + "hdrhistogram", + "indexmap", "pin-project", "pin-project-lite", + "rand", + "slab", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 706cba4..72495d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ base64 = "0.21.0" dotenvy = "0.15.6" ed25519-dalek = "2.0.0-pre.0" hex = "0.4.3" +hyper = "0.14" opentelemetry = { version = "0.18.0", features = ["trace", "rt-tokio"] } opentelemetry-jaeger = "0.17.0" serde = "1.0.152" @@ -28,6 +29,7 @@ sqlx = { version = "0.6.2", features = [ ]} time = "0.3.20" tokio = { version = "1.26.0", features = ["full"] } +tower = { version = "0.4.13", features = ["full"] } tower-http = { version = "0.4.0", features = ["trace"] } tracing = "0.1.37" tracing-loki = "0.2.2" diff --git a/src/commands.rs b/src/discord/commands.rs similarity index 85% rename from src/commands.rs rename to src/discord/commands.rs index d72bb0b..6ec4fea 100644 --- a/src/commands.rs +++ b/src/discord/commands.rs @@ -2,7 +2,9 @@ use axum::http::StatusCode; use sqlx::PgPool; use twilight_interactions::command::{CommandModel, CreateCommand}; use twilight_mention::{Mention, timestamp::{TimestampStyle, Timestamp}}; -use twilight_model::{id::{Id, marker::{InteractionMarker, ChannelMarker, UserMarker}}, http::interaction::{InteractionResponse, InteractionResponseType, InteractionResponseData}, channel::message::MessageFlags}; +use twilight_model::{id::{Id, marker::{InteractionMarker, ChannelMarker, UserMarker, ApplicationMarker}}, http::interaction::{InteractionResponse, InteractionResponseType, InteractionResponseData}, channel::message::MessageFlags}; + +use crate::discord::discord_client; #[derive(CommandModel, CreateCommand)] #[command(name = "set_fact", desc = "Quietly save a fact")] @@ -94,3 +96,17 @@ pub async fn get_fact( }), }) } + +pub async fn register_commands( + discord_client_id: Id, + authorization: String, +) -> anyhow::Result<()> { + discord_client(authorization)? + .interaction(discord_client_id) + .set_global_commands(&[ + GetFactCommand::create_command().into(), + SetFactCommand::create_command().into(), + ]) + .await?; + Ok(()) +} diff --git a/src/discord/interactions.rs b/src/discord/interactions.rs new file mode 100644 index 0000000..a2181c6 --- /dev/null +++ b/src/discord/interactions.rs @@ -0,0 +1,52 @@ +use axum::{ + async_trait, + extract::{FromRef, FromRequest}, + http::{Request, StatusCode}, +}; +use ed25519_dalek::{Signature, VerifyingKey}; +use hyper::Body; +use twilight_model::application::interaction::Interaction; + +pub struct ExtractInteraction(pub Interaction); + +#[async_trait] +impl FromRequest for ExtractInteraction +where + S: Send + Sync, + VerifyingKey: FromRef, +{ + type Rejection = (StatusCode, String); + + async fn from_request(request: Request, state: &S) -> Result { + let (parts, body) = request.into_parts(); + let body_bytes = hyper::body::to_bytes(body) + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?; + + let Ok(interaction): Result = serde_json::from_slice(&body_bytes) else { + return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string())) + }; + let Some(sig) = parts.headers.get("x-signature-ed25519") else { + return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string())) + }; + let Ok(sig) = hex::decode(sig) else { + return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string())) + }; + let Ok(sig) = Signature::from_slice(&sig) else { + return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string())) + }; + let Some(signed_buf) = parts.headers.get("x-signature-timestamp") else { + return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string())) + }; + let mut signed_buf = signed_buf.as_bytes().to_owned(); + signed_buf.extend(body_bytes); + + let public_key = VerifyingKey::from_ref(state); + let Ok(()) = public_key.verify_strict(&signed_buf, &sig) else { + return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string())) + }; + + Ok(Self(interaction)) + } +} + diff --git a/src/discord/mod.rs b/src/discord/mod.rs new file mode 100644 index 0000000..2cea92c --- /dev/null +++ b/src/discord/mod.rs @@ -0,0 +1,26 @@ +use serde::Deserialize; +use twilight_http::Client; + +pub mod commands; +pub mod interactions; + +#[derive(Deserialize)] +struct ClientCredentialsResponse { + access_token: String, +} + +fn client_credentials_grant(authorization: String) -> anyhow::Result { + Ok(ureq::post("https://discord.com/api/v10/oauth2/token") + .set("Authorization", &format!("Basic {}", authorization)) + .send_form(&[ + ("grant_type", "client_credentials"), + ("scope", "applications.commands.update"), + ]) + .unwrap() + .into_json()?) +} + +pub fn discord_client(authorization: String) -> anyhow::Result { + let token = client_credentials_grant(authorization)?.access_token; + Ok(Client::new(format!("Bearer {token}"))) +} diff --git a/src/main.rs b/src/main.rs index bf7dbad..443cebd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,32 +1,37 @@ use std::{net::SocketAddr, process, str::FromStr}; -use anyhow::{bail}; +use anyhow::bail; use axum::{ - extract::State, - http::{HeaderMap, StatusCode}, + extract::{FromRef, State}, + http::StatusCode, routing::post, Json, Router, }; use base64::{alphabet, engine, Engine}; -use commands::{get_fact, set_fact, GetFactCommand, SetFactCommand}; -use ed25519_dalek::{Signature, VerifyingKey}; -use serde::Deserialize; + +use discord::interactions::ExtractInteraction; +use ed25519_dalek::VerifyingKey; + use sqlx::{postgres::PgPoolOptions, PgPool}; use tokio::signal; + use tower_http::trace::TraceLayer; use tracing_subscriber::{ filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, Layer, }; -use twilight_http::Client; + use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand}; use twilight_model::{ - application::interaction::{Interaction, InteractionData, InteractionType}, + application::interaction::{InteractionData, InteractionType}, http::interaction::{InteractionResponse, InteractionResponseType}, - id::{Id, marker::ApplicationMarker}, + id::{marker::ApplicationMarker, Id}, }; -mod commands; + +use crate::discord::commands::{register_commands, SetFactCommand, set_fact, GetFactCommand, get_fact}; + mod database; +mod discord; #[derive(Clone)] struct Config { @@ -46,14 +51,13 @@ impl Config { Ok(pk) => pk, Err(_) => bail!("Invalid discord public key"), }; - - + Ok(Config { discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?, discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?, discord_pub_key: VerifyingKey::from_bytes(&pub_key)?, database_url: std::env::var("DATABASE_URL")?, - listen_port:std::env::var("LISTEN_PORT")?.parse()?, + listen_port: std::env::var("LISTEN_PORT")?.parse()?, }) } @@ -71,6 +75,18 @@ struct AppState { pg_pool: PgPool, } +impl FromRef for PgPool { + fn from_ref(app_state: &AppState) -> Self { + app_state.pg_pool.clone() + } +} + +impl FromRef for VerifyingKey { + fn from_ref(app_state: &AppState) -> Self { + app_state.config.discord_pub_key + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Config::configure()?; @@ -106,9 +122,11 @@ async fn main() -> anyhow::Result<()> { sqlx::migrate!().run(&pg_pool).await?; register_commands(config.discord_client_id.to_owned(), config.authorization()).await?; + let state = AppState { config, pg_pool }; + let app = Router::new() .route("/api/discord/interactions/", post(post_interaction)) - .with_state(AppState { config, pg_pool }) + .with_state(state) .layer(TraceLayer::new_for_http()); let addr = SocketAddr::from(([127, 0, 0, 1], port)); @@ -155,43 +173,10 @@ async fn shutdown_signal() { type InteractionResult = Result<(StatusCode, Json), (StatusCode, String)>; -fn validate_request(headers: HeaderMap, body: String, pub_key: VerifyingKey) -> Result { - let Ok(interaction): Result = serde_json::from_str(&body) else { - return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string())) - }; - let Some(sig) = headers.get("x-signature-ed25519") else { - return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string())) - }; - let Ok(sig) = hex::decode(sig) else { - return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string())) - }; - let Ok(sig) = Signature::from_slice(&sig) else { - return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string())) - }; - let Some(signed_buf) = headers.get("x-signature-timestamp") else { - return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string())) - }; - let mut signed_buf = signed_buf.as_bytes().to_owned(); - signed_buf.extend(body.as_bytes()); - - let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else { - return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string())) - }; - - Ok(interaction) -} - - async fn post_interaction( - headers: HeaderMap, - State(app_state): State, - body: String, + State(pg_pool): State, + ExtractInteraction(interaction): ExtractInteraction, ) -> InteractionResult { - let interaction = match validate_request(headers, body, app_state.config.discord_pub_key) { - Ok(interaction) => interaction, - Err(error) => return Err(error), - }; - match interaction.kind { InteractionType::Ping => { let pong = InteractionResponse { @@ -203,7 +188,7 @@ async fn post_interaction( } InteractionType::ApplicationCommand => { let author_id = interaction.author_id(); - let Some(InteractionData::ApplicationCommand(data)) = interaction.data else { + let Some(InteractionData::ApplicationCommand(data)) = interaction.data.clone() else { return not_found(); }; let command_input_data = CommandInputData::from(*data.clone()); @@ -221,7 +206,7 @@ async fn post_interaction( interaction.channel_id, author_id, command_data, - &app_state.pg_pool, + &pg_pool, ) .await { @@ -237,7 +222,7 @@ async fn post_interaction( return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME))); }; - match get_fact(interaction.channel_id, author_id, command_data, &app_state.pg_pool).await + match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await { Ok(response) => Ok((StatusCode::OK, Json(response))), Err(err) => Err(err), @@ -259,37 +244,3 @@ fn not_found() -> InteractionResult { "requested interaction not found".to_string(), )) } - -async fn register_commands(discord_client_id: Id, authorization: String) -> anyhow::Result<()> { - discord_client(authorization)? - .interaction(discord_client_id) - .set_global_commands(&[ - GetFactCommand::create_command().into(), - SetFactCommand::create_command().into(), - ]) - .await?; - Ok(()) -} - -#[derive(Deserialize)] -struct ClientCredentialsResponse { - access_token: String, -} - - -fn client_credentials_grant(authorization: String) -> anyhow::Result { - Ok(ureq::post("https://discord.com/api/v10/oauth2/token") - .set("Authorization", &format!("Basic {}", authorization)) - .send_form(&[ - ("grant_type", "client_credentials"), - ("scope", "applications.commands.update"), - ]) - .unwrap() - .into_json()?) -} - -fn discord_client(authorization: String) -> anyhow::Result { - let token = client_credentials_grant(authorization)?.access_token; - Ok(Client::new(format!("Bearer {token}"))) -} -