This repository has been archived on 2023-10-11. You can view files and clone it, but cannot push or open issues or pull requests.

204 lines
6.8 KiB
Rust

use std::{net::SocketAddr, str::FromStr};
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use base64::{alphabet, engine, Engine};
use commands::{SetFactCommand, GetFactCommand, set_fact, get_fact};
use ed25519_dalek::{Signature, VerifyingKey};
use serde::Deserialize;
use sqlx::{postgres::PgPoolOptions, PgPool};
use twilight_http::Client;
use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand};
use twilight_model::{
application::interaction::{Interaction, InteractionData, InteractionType},
http::interaction::{InteractionResponse, InteractionResponseType},
id::Id
};
mod commands;
mod database;
/// Entry point: connects to Postgres, applies migrations, registers the global
/// slash commands with Discord, then serves the interaction webhook on
/// `127.0.0.1:4635` until the server stops.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // A missing `.env` file is fine; variables may come from the real environment.
    dotenvy::dotenv().ok();

    let database = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url())
        .await?;
    sqlx::migrate!().run(&database).await?;

    register_commands().await;

    let port: u16 = 4635;
    let listen_on = SocketAddr::from(([127, 0, 0, 1], port));
    let router = Router::new()
        .route("/", post(post_interaction))
        .with_state(database);

    axum::Server::bind(&listen_on)
        .serve(router.into_make_service())
        .await?;
    Ok(())
}
/// Outcome of handling one interaction webhook request.
///
/// `Ok` carries the HTTP status plus the JSON interaction response sent back to Discord;
/// `Err` carries an error status plus a plain-text message.
type InteractionResult = Result<
    (StatusCode, Json<InteractionResponse>),
    (StatusCode, String),
>;
/// Authenticates an incoming Discord interaction webhook request and parses its body.
///
/// Discord signs each request with Ed25519 over `timestamp || body`. The hex signature
/// arrives in `x-signature-ed25519` and the timestamp in `x-signature-timestamp`. The
/// signature is verified against the application's public key *before* the body is
/// deserialized, so unauthenticated payloads are never parsed.
///
/// # Errors
/// - `401 Unauthorized` if either signature header is missing, the signature is not
///   valid hex or has the wrong length, or it fails verification.
/// - `400 Bad Request` if an authenticated body is not a valid interaction JSON payload.
fn validate_request(headers: HeaderMap, body: String) -> Result<Interaction, (StatusCode, String)> {
    let unauthorized = |message: &str| (StatusCode::UNAUTHORIZED, message.to_string());

    let sig = headers
        .get("x-signature-ed25519")
        .ok_or_else(|| unauthorized("request did not include signature header"))?;
    let sig = hex::decode(sig).map_err(|_| unauthorized("request signature is invalid hex"))?;
    let sig =
        Signature::from_slice(&sig).map_err(|_| unauthorized("request signature is malformed"))?;
    let timestamp = headers
        .get("x-signature-timestamp")
        .ok_or_else(|| unauthorized("request did not include signature timestamp header"))?;

    // The signed message is the raw timestamp header bytes followed by the raw body.
    let mut signed_buf = timestamp.as_bytes().to_vec();
    signed_buf.extend_from_slice(body.as_bytes());
    discord_pub_key()
        .verify_strict(&signed_buf, &sig)
        .map_err(|_| unauthorized("interaction failed signature verification"))?;

    // Only parse once the payload is known to come from Discord.
    serde_json::from_str(&body)
        .map_err(|_| (StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
}
/// Axum handler for Discord's interaction webhook (`POST /`).
///
/// The request is authenticated and parsed by `validate_request`; its error is returned
/// unchanged. `Ping` interactions are answered with `Pong`. Application commands are
/// dispatched by name to `set_fact` / `get_fact`, which both require the invoking user.
/// Any other interaction type or unknown command name yields `not_found()`.
async fn post_interaction(
    headers: HeaderMap,
    State(pg_pool): State<PgPool>,
    body: String,
) -> InteractionResult {
    let interaction = validate_request(headers, body)?;
    match interaction.kind {
        InteractionType::Ping => {
            let pong = InteractionResponse {
                kind: InteractionResponseType::Pong,
                data: None,
            };
            Ok((StatusCode::OK, Json(pong)))
        }
        InteractionType::ApplicationCommand => {
            let author_id = interaction.author_id();
            let Some(InteractionData::ApplicationCommand(data)) = interaction.data else {
                return not_found();
            };
            // Keep only the name, then move the payload into the parser instead of
            // deep-cloning the whole command data.
            let command_name = data.name.clone();
            let command_input_data = CommandInputData::from(*data);
            match command_name.as_str() {
                SetFactCommand::NAME => {
                    let Ok(command_data) = SetFactCommand::from_interaction(command_input_data)
                    else {
                        return Err((
                            StatusCode::BAD_REQUEST,
                            format!("invalid {0} command.", SetFactCommand::NAME),
                        ));
                    };
                    let Some(author_id) = author_id else {
                        return Err((
                            StatusCode::BAD_REQUEST,
                            format!("{0} requires a user.", SetFactCommand::NAME),
                        ));
                    };
                    set_fact(
                        interaction.id,
                        interaction.channel_id,
                        author_id,
                        command_data,
                        &pg_pool,
                    )
                    .await
                    .map(|response| (StatusCode::OK, Json(response)))
                }
                GetFactCommand::NAME => {
                    let Ok(command_data) = GetFactCommand::from_interaction(command_input_data)
                    else {
                        return Err((
                            StatusCode::BAD_REQUEST,
                            format!("invalid {0} command.", GetFactCommand::NAME),
                        ));
                    };
                    let Some(author_id) = author_id else {
                        return Err((
                            StatusCode::BAD_REQUEST,
                            format!("{0} requires a user.", GetFactCommand::NAME),
                        ));
                    };
                    get_fact(interaction.channel_id, author_id, command_data, &pg_pool)
                        .await
                        .map(|response| (StatusCode::OK, Json(response)))
                }
                _ => not_found(),
            }
        }
        _ => not_found(),
    }
}
/// Rejection used for interaction types and command names this server does not handle.
///
/// Always an `Err` with `404 Not Found` and a fixed plain-text message.
fn not_found() -> InteractionResult {
    let status = StatusCode::NOT_FOUND;
    let message = String::from("requested interaction not found");
    Err((status, message))
}
/// Reads the application's public key from `DISCORD_PUB_KEY` and hex-decodes it.
///
/// Surrounding whitespace in the variable is ignored.
///
/// # Panics
/// Panics with a message naming `DISCORD_PUB_KEY` if the variable is unset, is not valid
/// Unicode, or is not valid hex.
fn discord_pub_key_bytes() -> Vec<u8> {
    let hex_key = std::env::var("DISCORD_PUB_KEY").expect("DISCORD_PUB_KEY must be set");
    hex::decode(hex_key.trim()).expect("DISCORD_PUB_KEY must be a hex-encoded public key")
}
/// Builds the Ed25519 verifying key used to authenticate Discord webhook requests.
///
/// # Panics
/// Panics if `DISCORD_PUB_KEY` cannot be read or decoded (see `discord_pub_key_bytes`).
/// Also panics, with a message naming the variable, if it does not decode to exactly
/// 32 bytes or the bytes are not a valid Ed25519 public key.
fn discord_pub_key() -> VerifyingKey {
    let pub_key_bytes: [u8; 32] = discord_pub_key_bytes()
        .try_into()
        .expect("DISCORD_PUB_KEY must decode to exactly 32 bytes");
    VerifyingKey::from_bytes(&pub_key_bytes)
        .expect("DISCORD_PUB_KEY must be a valid Ed25519 public key")
}
async fn register_commands() {
discord_client()
.interaction(Id::from_str(&discord_client_id()).unwrap())
.set_global_commands(&[
GetFactCommand::create_command().into(),
SetFactCommand::create_command().into(),
])
.await
.unwrap();
}
/// The part of Discord's OAuth2 client-credentials token response that this bot reads.
///
/// Only `access_token` is declared. Serde skips any other fields in the JSON because
/// `deny_unknown_fields` is not set.
/// NOTE(review): no `Debug` derive — if one is ever added, redact the token first.
#[derive(Deserialize)]
struct ClientCredentialsResponse {
    /// Token that `discord_client` sends as `Bearer <token>`.
    access_token: String,
}
/// HTTP Basic credentials for Discord's OAuth2 token endpoint.
///
/// Returns `base64("<client_id>:<client_secret>")`, encoded with the standard alphabet
/// and `=` padding.
fn authorization() -> String {
    let credentials = [discord_client_id(), discord_client_secret()].join(":");
    let padded_standard =
        engine::GeneralPurpose::new(&alphabet::STANDARD, engine::general_purpose::PAD);
    padded_standard.encode(credentials)
}
/// Performs the OAuth2 client-credentials grant against Discord.
///
/// Requests the `applications.commands.update` scope and returns the parsed token response.
/// Panics if the HTTP request fails or the response body is not the expected JSON.
fn client_credentials_grant() -> ClientCredentialsResponse {
    const TOKEN_URL: &str = "https://discord.com/api/v10/oauth2/token";
    let form = [
        ("grant_type", "client_credentials"),
        ("scope", "applications.commands.update"),
    ];

    let request = ureq::post(TOKEN_URL);
    let basic_auth = format!("Basic {}", authorization());
    let response = request
        .set("Authorization", &basic_auth)
        .send_form(&form)
        .unwrap();
    response.into_json().unwrap()
}
/// Discord application (client) id, read from `DISCORD_CLIENT_ID`.
///
/// # Panics
/// Panics with a message naming `DISCORD_CLIENT_ID` if it is unset or not valid Unicode.
fn discord_client_id() -> String {
    std::env::var("DISCORD_CLIENT_ID").expect("DISCORD_CLIENT_ID must be set")
}
/// Discord OAuth2 client secret, read from `DISCORD_CLIENT_SECRET`.
///
/// # Panics
/// Panics with a message naming `DISCORD_CLIENT_SECRET` if it is unset or not valid Unicode.
fn discord_client_secret() -> String {
    std::env::var("DISCORD_CLIENT_SECRET").expect("DISCORD_CLIENT_SECRET must be set")
}
/// Creates a twilight HTTP client authenticated with a freshly granted OAuth2 bearer token.
///
/// Each call performs a new client-credentials grant; the token is not cached.
fn discord_client() -> Client {
    let grant = client_credentials_grant();
    let bearer = String::from("Bearer ") + &grant.access_token;
    Client::new(bearer)
}
/// Postgres connection string for the pool, read from `DATABASE_URL`.
///
/// # Panics
/// Panics with a message naming `DATABASE_URL` if it is unset or not valid Unicode.
fn database_url() -> String {
    std::env::var("DATABASE_URL").expect("DATABASE_URL must be set")
}