diff --git a/Cargo.lock b/Cargo.lock index fdfef9c..7a2aa5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -639,6 +639,7 @@ dependencies = [ "hyper", "opentelemetry", "opentelemetry-jaeger", + "rand", "serde", "serde_json", "sqlx", diff --git a/Cargo.toml b/Cargo.toml index 72495d2..e11c51c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,4 +42,5 @@ twilight-model = "0.15.1" twilight-util = { version = "0.15.1", features = ["builder"] } url = "2.3.1" ureq = { version = "2.6.2", features = ["json"] } -uuid = "1.3.0" \ No newline at end of file +uuid = "1.3.0" +rand = "0.8.5" diff --git a/src/discord/commands.rs b/src/discord/commands.rs index 6ec4fea..bcb66a0 100644 --- a/src/discord/commands.rs +++ b/src/discord/commands.rs @@ -1,5 +1,7 @@ use axum::http::StatusCode; +use rand::{thread_rng, Rng}; use sqlx::PgPool; + use twilight_interactions::command::{CommandModel, CreateCommand}; use twilight_mention::{Mention, timestamp::{TimestampStyle, Timestamp}}; use twilight_model::{id::{Id, marker::{InteractionMarker, ChannelMarker, UserMarker, ApplicationMarker}}, http::interaction::{InteractionResponse, InteractionResponseType, InteractionResponseData}, channel::message::MessageFlags}; @@ -15,6 +17,15 @@ pub struct SetFactCommand { fact_value: String, } +#[derive(CommandModel, CreateCommand)] +#[command(name = "roll_fact", desc = "Roll a number between 0 and 99, inclusive, compare to threshold, save success state in a fact")] +pub struct RollFactCommand { + #[command(rename = "name", desc = "Fact name")] + fact_name: String, + #[command(desc = "Difficulty, in percent, 0-100, default 50", max_value=100, min_value=0)] + difficulty: Option, +} + #[derive(CommandModel, CreateCommand)] #[command(name = "get_fact", desc = "Retrieve and display the value of a fact")] pub struct GetFactCommand { @@ -97,6 +108,43 @@ pub async fn get_fact( }) } +pub async fn roll_fact( + interaction_id: Id, + channel_id: Option>, + author_id: Id, + command_data: RollFactCommand, + pg_pool: &PgPool, +) -> Result { + let roll = thread_rng().gen_range(0..100); + let difficulty = command_data.difficulty.unwrap_or(50); + let result = match roll >= difficulty { + true => "success", + false => "failure", + }; + + let Ok(()) = crate::database::set_fact( + pg_pool, + interaction_id.to_string(), + channel_id.map(|cid| cid.to_string()), + author_id.to_string(), + command_data.fact_name.to_owned(), + format!("rolled {0} against {1}, {2}", roll, difficulty, result)).await else { + return Err((StatusCode::INTERNAL_SERVER_ERROR, "Error saving fact.".to_string())); + }; + + Ok(InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some(InteractionResponseData { + content: Some(format!( + "Rolled {0} against {1} for {2}, {3}", + roll, difficulty, command_data.fact_name, result + )), + flags: Some(MessageFlags::EPHEMERAL), + ..Default::default() + }), + }) +} + pub async fn register_commands( discord_client_id: Id, authorization: String, @@ -105,6 +153,7 @@ pub async fn register_commands( .interaction(discord_client_id) .set_global_commands(&[ GetFactCommand::create_command().into(), + RollFactCommand::create_command().into(), SetFactCommand::create_command().into(), ]) .await?; diff --git a/src/discord/interactions.rs b/src/discord/interactions.rs index a2181c6..46756db 100644 --- a/src/discord/interactions.rs +++ b/src/discord/interactions.rs @@ -1,11 +1,19 @@ use axum::{ async_trait, - extract::{FromRef, FromRequest}, + extract::{FromRef, FromRequest, State}, http::{Request, StatusCode}, + Json, }; use ed25519_dalek::{Signature, VerifyingKey}; use hyper::Body; -use twilight_model::application::interaction::Interaction; +use sqlx::PgPool; +use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand}; +use twilight_model::{ + application::interaction::{Interaction, InteractionData, InteractionType}, + http::interaction::{InteractionResponse, InteractionResponseType}, +}; + +use crate::discord::commands::{get_fact, set_fact, GetFactCommand, SetFactCommand, RollFactCommand, roll_fact}; pub struct ExtractInteraction(pub Interaction); @@ -50,3 +58,96 @@ where } } +type InteractionResult = Result<(StatusCode, Json), (StatusCode, String)>; + +pub async fn post_interaction( + State(pg_pool): State, + ExtractInteraction(interaction): ExtractInteraction, +) -> InteractionResult { + match interaction.kind { + InteractionType::Ping => { + let pong = InteractionResponse { + kind: InteractionResponseType::Pong, + data: None, + }; + + Ok((StatusCode::OK, Json(pong))) + } + InteractionType::ApplicationCommand => { + let author_id = interaction.author_id(); + let Some(InteractionData::ApplicationCommand(data)) = interaction.data.clone() else { + return not_found(); + }; + let command_input_data = CommandInputData::from(*data.clone()); + tracing::debug!(command_name = ?data.name, "started processing command"); + let result = match &*data.name { + SetFactCommand::NAME => { + let Ok(command_data) = SetFactCommand::from_interaction(command_input_data) else { + return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", SetFactCommand::NAME))); + }; + let Some(author_id) = author_id else { + return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", SetFactCommand::NAME))); + }; + match set_fact( + interaction.id, + interaction.channel_id, + author_id, + command_data, + &pg_pool, + ) + .await + { + Ok(response) => Ok((StatusCode::OK, Json(response))), + Err(err) => Err(err), + } + } + GetFactCommand::NAME => { + let Ok(command_data) = GetFactCommand::from_interaction(command_input_data) else { + return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", GetFactCommand::NAME))); + }; + let Some(author_id) = author_id else { + return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME))); + }; + + match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await + { + Ok(response) => Ok((StatusCode::OK, Json(response))), + Err(err) => Err(err), + } + } + RollFactCommand::NAME => { + let Ok(command_data) = RollFactCommand::from_interaction(command_input_data) else { + return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", SetFactCommand::NAME))); + }; + let Some(author_id) = author_id else { + return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", SetFactCommand::NAME))); + }; + match roll_fact( + interaction.id, + interaction.channel_id, + author_id, + command_data, + &pg_pool, + ) + .await + { + Ok(response) => Ok((StatusCode::OK, Json(response))), + Err(err) => Err(err), + } + } + _ => not_found() + }; + + tracing::debug!(command_name = ?data.name, "finished processing command"); + result + } + _ => not_found(), + } +} + +fn not_found() -> InteractionResult { + Err(( + StatusCode::NOT_FOUND, + "requested interaction not found".to_string(), + )) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..33d6378 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,65 @@ +use std::str::FromStr; + +use anyhow::bail; +use axum::extract::FromRef; +use base64::{engine, alphabet, Engine}; +use ed25519_dalek::VerifyingKey; +use sqlx::PgPool; +use twilight_model::id::{marker::ApplicationMarker, Id}; + +pub mod database; +pub mod discord; + +#[derive(Clone)] +pub struct Config { + pub discord_client_id: Id, + discord_client_secret: String, + discord_pub_key: VerifyingKey, + pub database_url: String, + pub listen_port: u16, +} + +impl Config { + pub fn configure() -> anyhow::Result { + dotenvy::dotenv().ok(); + + let pub_key_bytes: Vec = hex::decode(std::env::var("DISCORD_PUB_KEY")?)?; + let pub_key: [u8; 32] = match pub_key_bytes.try_into() { + Ok(pk) => pk, + Err(_) => bail!("Invalid discord public key"), + }; + + Ok(Config { + discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?, + discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?, + discord_pub_key: VerifyingKey::from_bytes(&pub_key)?, + database_url: std::env::var("DATABASE_URL")?, + listen_port: std::env::var("LISTEN_PORT")?.parse()?, + }) + } + + pub fn authorization(&self) -> String { + let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD); + let auth = format!("{}:{}", self.discord_client_id, self.discord_client_secret); + + engine.encode(auth) + } +} + +#[derive(Clone)] +pub struct AppState { + pub config: Config, + pub pg_pool: PgPool, +} + +impl FromRef for PgPool { + fn from_ref(app_state: &AppState) -> Self { + app_state.pg_pool.clone() + } +} + +impl FromRef for VerifyingKey { + fn from_ref(app_state: &AppState) -> Self { + app_state.config.discord_pub_key + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 443cebd..f407eb1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,8 @@ -use std::{net::SocketAddr, process, str::FromStr}; +use std::{net::SocketAddr, process}; -use anyhow::bail; -use axum::{ - extract::{FromRef, State}, - http::StatusCode, - routing::post, - Json, Router, -}; -use base64::{alphabet, engine, Engine}; +use axum::{routing::post, Router}; -use discord::interactions::ExtractInteraction; -use ed25519_dalek::VerifyingKey; - -use sqlx::{postgres::PgPoolOptions, PgPool}; +use sqlx::postgres::PgPoolOptions; use tokio::signal; use tower_http::trace::TraceLayer; @@ -20,73 +10,11 @@ use tracing_subscriber::{ filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, Layer, }; -use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand}; - -use twilight_model::{ - application::interaction::{InteractionData, InteractionType}, - http::interaction::{InteractionResponse, InteractionResponseType}, - id::{marker::ApplicationMarker, Id}, +use god_replacement_product::{ + discord::{commands::register_commands, interactions::post_interaction}, + AppState, Config, }; -use crate::discord::commands::{register_commands, SetFactCommand, set_fact, GetFactCommand, get_fact}; - -mod database; -mod discord; - -#[derive(Clone)] -struct Config { - discord_client_id: Id, - discord_client_secret: String, - discord_pub_key: VerifyingKey, - database_url: String, - listen_port: u16, -} - -impl Config { - fn configure() -> anyhow::Result { - dotenvy::dotenv().ok(); - - let pub_key_bytes: Vec = hex::decode(std::env::var("DISCORD_PUB_KEY")?)?; - let pub_key: [u8; 32] = match pub_key_bytes.try_into() { - Ok(pk) => pk, - Err(_) => bail!("Invalid discord public key"), - }; - - Ok(Config { - discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?, - discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?, - discord_pub_key: VerifyingKey::from_bytes(&pub_key)?, - database_url: std::env::var("DATABASE_URL")?, - listen_port: std::env::var("LISTEN_PORT")?.parse()?, - }) - } - - fn authorization(&self) -> String { - let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD); - let auth = format!("{}:{}", self.discord_client_id, self.discord_client_secret); - - engine.encode(auth) - } -} - -#[derive(Clone)] -struct AppState { - config: Config, - pg_pool: PgPool, -} - -impl FromRef for PgPool { - fn from_ref(app_state: &AppState) -> Self { - app_state.pg_pool.clone() - } -} - -impl FromRef for VerifyingKey { - fn from_ref(app_state: &AppState) -> Self { - app_state.config.discord_pub_key - } -} - #[tokio::main] async fn main() -> anyhow::Result<()> { let config = Config::configure()?; @@ -170,77 +98,3 @@ async fn shutdown_signal() { tracing::info!("signal received, starting graceful shutdown"); } - -type InteractionResult = Result<(StatusCode, Json), (StatusCode, String)>; - -async fn post_interaction( - State(pg_pool): State, - ExtractInteraction(interaction): ExtractInteraction, -) -> InteractionResult { - match interaction.kind { - InteractionType::Ping => { - let pong = InteractionResponse { - kind: InteractionResponseType::Pong, - data: None, - }; - - Ok((StatusCode::OK, Json(pong))) - } - InteractionType::ApplicationCommand => { - let author_id = interaction.author_id(); - let Some(InteractionData::ApplicationCommand(data)) = interaction.data.clone() else { - return not_found(); - }; - let command_input_data = CommandInputData::from(*data.clone()); - tracing::debug!(command_name = ?data.name, "started processing command"); - let result = match &*data.name { - SetFactCommand::NAME => { - let Ok(command_data) = SetFactCommand::from_interaction(command_input_data) else { - return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", SetFactCommand::NAME))); - }; - let Some(author_id) = author_id else { - return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", SetFactCommand::NAME))); - }; - match set_fact( - interaction.id, - interaction.channel_id, - author_id, - command_data, - &pg_pool, - ) - .await - { - Ok(response) => Ok((StatusCode::OK, Json(response))), - Err(err) => Err(err), - } - } - GetFactCommand::NAME => { - let Ok(command_data) = GetFactCommand::from_interaction(command_input_data) else { - return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", GetFactCommand::NAME))); - }; - let Some(author_id) = author_id else { - return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME))); - }; - - match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await - { - Ok(response) => Ok((StatusCode::OK, Json(response))), - Err(err) => Err(err), - } - } - _ => not_found(), - }; - - tracing::debug!(command_name = ?data.name, "finished processing command"); - result - } - _ => not_found(), - } -} - -fn not_found() -> InteractionResult { - Err(( - StatusCode::NOT_FOUND, - "requested interaction not found".to_string(), - )) -}