use std::{net::SocketAddr, str::FromStr}; use axum::{ extract::State, http::{HeaderMap, StatusCode}, routing::post, Json, Router, }; use base64::{alphabet, engine, Engine}; use commands::{SetFactCommand, GetFactCommand, set_fact, get_fact}; use ed25519_dalek::{Signature, VerifyingKey}; use serde::Deserialize; use sqlx::{postgres::PgPoolOptions, PgPool}; use twilight_http::Client; use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand}; use twilight_model::{ application::interaction::{Interaction, InteractionData, InteractionType}, http::interaction::{InteractionResponse, InteractionResponseType}, id::Id }; mod commands; mod database; #[tokio::main] async fn main() -> anyhow::Result<()> { let port = 4635; dotenvy::dotenv().ok(); let pg_pool = PgPoolOptions::new() .max_connections(5) .connect(database_url().as_str()) .await?; sqlx::migrate!().run(&pg_pool).await?; let app = Router::new() .route("/", post(post_interaction)) .with_state(pg_pool); let addr = SocketAddr::from(([127, 0, 0, 1], port)); register_commands().await; axum::Server::bind(&addr) .serve(app.into_make_service()) .await?; Ok(()) } type InteractionResult = Result<(StatusCode, Json), (StatusCode, String)>; fn validate_request(headers: HeaderMap, body: String) -> Result { let Ok(interaction): Result = serde_json::from_str(&body) else { return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string())) }; let Some(sig) = headers.get("x-signature-ed25519") else { return Err((StatusCode::BAD_REQUEST, "requrest did not include signature header".to_string())) }; let Ok(sig) = hex::decode(sig) else { return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string())) }; let Ok(sig) = Signature::from_slice(&sig) else { return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string())) }; let Some(signed_buf) = headers.get("x-signature-timestamp") else { return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string())) }; let mut signed_buf = signed_buf.as_bytes().to_owned(); signed_buf.extend(body.as_bytes()); let pub_key = discord_pub_key(); let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else { return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string())) }; Ok(interaction) } async fn post_interaction( headers: HeaderMap, State(pg_pool): State, body: String, ) -> InteractionResult { let interaction = match validate_request(headers, body) { Ok(interaction) => interaction, Err(error) => return Err(error), }; match interaction.kind { InteractionType::Ping => { let pong = InteractionResponse { kind: InteractionResponseType::Pong, data: None, }; Ok((StatusCode::OK, Json(pong))) } InteractionType::ApplicationCommand => { let author_id = interaction.author_id(); let Some(InteractionData::ApplicationCommand(data)) = interaction.data else { return not_found(); }; let command_input_data = CommandInputData::from(*data.clone()); match &*data.name { SetFactCommand::NAME => { let Ok(command_data) = SetFactCommand::from_interaction(command_input_data) else { return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", SetFactCommand::NAME))); }; let Some(author_id) = author_id else { return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", SetFactCommand::NAME))); }; match set_fact(interaction.id, interaction.channel_id, author_id, command_data, &pg_pool).await { Ok(response) => Ok((StatusCode::OK, Json(response))), Err(err) => Err(err), } }, GetFactCommand::NAME => { let Ok(command_data) = GetFactCommand::from_interaction(command_input_data) else { return Err((StatusCode::BAD_REQUEST, format!("invalid {0} command.", GetFactCommand::NAME))); }; let Some(author_id) = author_id else { return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME))); }; match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await { Ok(response) => Ok((StatusCode::OK, Json(response))), Err(err) => Err(err), } }, _ => not_found(), } } _ => not_found(), } } fn not_found() -> InteractionResult { Err(( StatusCode::NOT_FOUND, "requested interaction not found".to_string(), )) } fn discord_pub_key_bytes() -> Vec { hex::decode(std::env::var("DISCORD_PUB_KEY").unwrap()).unwrap() } fn discord_pub_key() -> VerifyingKey { let pub_key_bytes: [u8; 32] = discord_pub_key_bytes().try_into().unwrap(); VerifyingKey::from_bytes(&pub_key_bytes).unwrap() } async fn register_commands() { discord_client() .interaction(Id::from_str(&discord_client_id()).unwrap()) .set_global_commands(&[ GetFactCommand::create_command().into(), SetFactCommand::create_command().into(), ]) .await .unwrap(); } #[derive(Deserialize)] struct ClientCredentialsResponse { access_token: String, } fn authorization() -> String { let engine = engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD); let auth = format!("{}:{}", discord_client_id(), discord_client_secret(),); engine.encode(auth) } fn client_credentials_grant() -> ClientCredentialsResponse { ureq::post("https://discord.com/api/v10/oauth2/token") .set("Authorization", &format!("Basic {}", authorization())) .send_form(&[ ("grant_type", "client_credentials"), ("scope", "applications.commands.update"), ]) .unwrap() .into_json() .unwrap() } fn discord_client_id() -> String { std::env::var("DISCORD_CLIENT_ID").unwrap() } fn discord_client_secret() -> String { std::env::var("DISCORD_CLIENT_SECRET").unwrap() } fn discord_client() -> Client { let token = client_credentials_grant().access_token; Client::new(format!("Bearer {token}")) } fn database_url() -> String { std::env::var("DATABASE_URL").unwrap() }