Lots of refactoring, mostly to figure out how Axum works, tbh
This commit is contained in:
parent
9c3dfba3ba
commit
eab9c83067
17
Cargo.lock
generated
17
Cargo.lock
generated
@ -636,6 +636,7 @@ dependencies = [
|
|||||||
"dotenvy",
|
"dotenvy",
|
||||||
"ed25519-dalek",
|
"ed25519-dalek",
|
||||||
"hex",
|
"hex",
|
||||||
|
"hyper",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-jaeger",
|
"opentelemetry-jaeger",
|
||||||
"serde",
|
"serde",
|
||||||
@ -643,6 +644,7 @@ dependencies = [
|
|||||||
"sqlx",
|
"sqlx",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-loki",
|
"tracing-loki",
|
||||||
@ -695,6 +697,16 @@ dependencies = [
|
|||||||
"hashbrown",
|
"hashbrown",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hdrhistogram"
|
||||||
|
version = "7.5.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7f19b9f54f7c7f55e31401bb647626ce0cf0f67b0004982ce815b3ee72a02aa8"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "heck"
|
name = "heck"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
@ -2135,9 +2147,14 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
"hdrhistogram",
|
||||||
|
"indexmap",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"rand",
|
||||||
|
"slab",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
@ -15,6 +15,7 @@ base64 = "0.21.0"
|
|||||||
dotenvy = "0.15.6"
|
dotenvy = "0.15.6"
|
||||||
ed25519-dalek = "2.0.0-pre.0"
|
ed25519-dalek = "2.0.0-pre.0"
|
||||||
hex = "0.4.3"
|
hex = "0.4.3"
|
||||||
|
hyper = "0.14"
|
||||||
opentelemetry = { version = "0.18.0", features = ["trace", "rt-tokio"] }
|
opentelemetry = { version = "0.18.0", features = ["trace", "rt-tokio"] }
|
||||||
opentelemetry-jaeger = "0.17.0"
|
opentelemetry-jaeger = "0.17.0"
|
||||||
serde = "1.0.152"
|
serde = "1.0.152"
|
||||||
@ -28,6 +29,7 @@ sqlx = { version = "0.6.2", features = [
|
|||||||
]}
|
]}
|
||||||
time = "0.3.20"
|
time = "0.3.20"
|
||||||
tokio = { version = "1.26.0", features = ["full"] }
|
tokio = { version = "1.26.0", features = ["full"] }
|
||||||
|
tower = { version = "0.4.13", features = ["full"] }
|
||||||
tower-http = { version = "0.4.0", features = ["trace"] }
|
tower-http = { version = "0.4.0", features = ["trace"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-loki = "0.2.2"
|
tracing-loki = "0.2.2"
|
||||||
|
@ -2,7 +2,9 @@ use axum::http::StatusCode;
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use twilight_interactions::command::{CommandModel, CreateCommand};
|
use twilight_interactions::command::{CommandModel, CreateCommand};
|
||||||
use twilight_mention::{Mention, timestamp::{TimestampStyle, Timestamp}};
|
use twilight_mention::{Mention, timestamp::{TimestampStyle, Timestamp}};
|
||||||
use twilight_model::{id::{Id, marker::{InteractionMarker, ChannelMarker, UserMarker}}, http::interaction::{InteractionResponse, InteractionResponseType, InteractionResponseData}, channel::message::MessageFlags};
|
use twilight_model::{id::{Id, marker::{InteractionMarker, ChannelMarker, UserMarker, ApplicationMarker}}, http::interaction::{InteractionResponse, InteractionResponseType, InteractionResponseData}, channel::message::MessageFlags};
|
||||||
|
|
||||||
|
use crate::discord::discord_client;
|
||||||
|
|
||||||
#[derive(CommandModel, CreateCommand)]
|
#[derive(CommandModel, CreateCommand)]
|
||||||
#[command(name = "set_fact", desc = "Quietly save a fact")]
|
#[command(name = "set_fact", desc = "Quietly save a fact")]
|
||||||
@ -94,3 +96,17 @@ pub async fn get_fact(
|
|||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn register_commands(
|
||||||
|
discord_client_id: Id<ApplicationMarker>,
|
||||||
|
authorization: String,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
discord_client(authorization)?
|
||||||
|
.interaction(discord_client_id)
|
||||||
|
.set_global_commands(&[
|
||||||
|
GetFactCommand::create_command().into(),
|
||||||
|
SetFactCommand::create_command().into(),
|
||||||
|
])
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
52
src/discord/interactions.rs
Normal file
52
src/discord/interactions.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
extract::{FromRef, FromRequest},
|
||||||
|
http::{Request, StatusCode},
|
||||||
|
};
|
||||||
|
use ed25519_dalek::{Signature, VerifyingKey};
|
||||||
|
use hyper::Body;
|
||||||
|
use twilight_model::application::interaction::Interaction;
|
||||||
|
|
||||||
|
pub struct ExtractInteraction(pub Interaction);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequest<S, Body> for ExtractInteraction
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
VerifyingKey: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = (StatusCode, String);
|
||||||
|
|
||||||
|
async fn from_request(request: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let (parts, body) = request.into_parts();
|
||||||
|
let body_bytes = hyper::body::to_bytes(body)
|
||||||
|
.await
|
||||||
|
.map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
|
||||||
|
|
||||||
|
let Ok(interaction): Result<Interaction, _> = serde_json::from_slice(&body_bytes) else {
|
||||||
|
return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
|
||||||
|
};
|
||||||
|
let Some(sig) = parts.headers.get("x-signature-ed25519") else {
|
||||||
|
return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string()))
|
||||||
|
};
|
||||||
|
let Ok(sig) = hex::decode(sig) else {
|
||||||
|
return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string()))
|
||||||
|
};
|
||||||
|
let Ok(sig) = Signature::from_slice(&sig) else {
|
||||||
|
return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string()))
|
||||||
|
};
|
||||||
|
let Some(signed_buf) = parts.headers.get("x-signature-timestamp") else {
|
||||||
|
return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string()))
|
||||||
|
};
|
||||||
|
let mut signed_buf = signed_buf.as_bytes().to_owned();
|
||||||
|
signed_buf.extend(body_bytes);
|
||||||
|
|
||||||
|
let public_key = VerifyingKey::from_ref(state);
|
||||||
|
let Ok(()) = public_key.verify_strict(&signed_buf, &sig) else {
|
||||||
|
return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string()))
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self(interaction))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
26
src/discord/mod.rs
Normal file
26
src/discord/mod.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
use serde::Deserialize;
|
||||||
|
use twilight_http::Client;
|
||||||
|
|
||||||
|
pub mod commands;
|
||||||
|
pub mod interactions;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ClientCredentialsResponse {
|
||||||
|
access_token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_credentials_grant(authorization: String) -> anyhow::Result<ClientCredentialsResponse> {
|
||||||
|
Ok(ureq::post("https://discord.com/api/v10/oauth2/token")
|
||||||
|
.set("Authorization", &format!("Basic {}", authorization))
|
||||||
|
.send_form(&[
|
||||||
|
("grant_type", "client_credentials"),
|
||||||
|
("scope", "applications.commands.update"),
|
||||||
|
])
|
||||||
|
.unwrap()
|
||||||
|
.into_json()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn discord_client(authorization: String) -> anyhow::Result<Client> {
|
||||||
|
let token = client_credentials_grant(authorization)?.access_token;
|
||||||
|
Ok(Client::new(format!("Bearer {token}")))
|
||||||
|
}
|
121
src/main.rs
121
src/main.rs
@ -1,32 +1,37 @@
|
|||||||
use std::{net::SocketAddr, process, str::FromStr};
|
use std::{net::SocketAddr, process, str::FromStr};
|
||||||
|
|
||||||
use anyhow::{bail};
|
use anyhow::bail;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::State,
|
extract::{FromRef, State},
|
||||||
http::{HeaderMap, StatusCode},
|
http::StatusCode,
|
||||||
routing::post,
|
routing::post,
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
use base64::{alphabet, engine, Engine};
|
use base64::{alphabet, engine, Engine};
|
||||||
use commands::{get_fact, set_fact, GetFactCommand, SetFactCommand};
|
|
||||||
use ed25519_dalek::{Signature, VerifyingKey};
|
use discord::interactions::ExtractInteraction;
|
||||||
use serde::Deserialize;
|
use ed25519_dalek::VerifyingKey;
|
||||||
|
|
||||||
use sqlx::{postgres::PgPoolOptions, PgPool};
|
use sqlx::{postgres::PgPoolOptions, PgPool};
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
|
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
use tracing_subscriber::{
|
use tracing_subscriber::{
|
||||||
filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, Layer,
|
filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, Layer,
|
||||||
};
|
};
|
||||||
use twilight_http::Client;
|
|
||||||
use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand};
|
use twilight_interactions::command::{CommandInputData, CommandModel, CreateCommand};
|
||||||
|
|
||||||
use twilight_model::{
|
use twilight_model::{
|
||||||
application::interaction::{Interaction, InteractionData, InteractionType},
|
application::interaction::{InteractionData, InteractionType},
|
||||||
http::interaction::{InteractionResponse, InteractionResponseType},
|
http::interaction::{InteractionResponse, InteractionResponseType},
|
||||||
id::{Id, marker::ApplicationMarker},
|
id::{marker::ApplicationMarker, Id},
|
||||||
};
|
};
|
||||||
mod commands;
|
|
||||||
|
use crate::discord::commands::{register_commands, SetFactCommand, set_fact, GetFactCommand, get_fact};
|
||||||
|
|
||||||
mod database;
|
mod database;
|
||||||
|
mod discord;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Config {
|
struct Config {
|
||||||
@ -47,13 +52,12 @@ impl Config {
|
|||||||
Err(_) => bail!("Invalid discord public key"),
|
Err(_) => bail!("Invalid discord public key"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
Ok(Config {
|
Ok(Config {
|
||||||
discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?,
|
discord_client_id: Id::from_str(std::env::var("DISCORD_CLIENT_ID")?.as_str())?,
|
||||||
discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?,
|
discord_client_secret: std::env::var("DISCORD_CLIENT_SECRET")?,
|
||||||
discord_pub_key: VerifyingKey::from_bytes(&pub_key)?,
|
discord_pub_key: VerifyingKey::from_bytes(&pub_key)?,
|
||||||
database_url: std::env::var("DATABASE_URL")?,
|
database_url: std::env::var("DATABASE_URL")?,
|
||||||
listen_port:std::env::var("LISTEN_PORT")?.parse()?,
|
listen_port: std::env::var("LISTEN_PORT")?.parse()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,6 +75,18 @@ struct AppState {
|
|||||||
pg_pool: PgPool,
|
pg_pool: PgPool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for PgPool {
|
||||||
|
fn from_ref(app_state: &AppState) -> Self {
|
||||||
|
app_state.pg_pool.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for VerifyingKey {
|
||||||
|
fn from_ref(app_state: &AppState) -> Self {
|
||||||
|
app_state.config.discord_pub_key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
let config = Config::configure()?;
|
let config = Config::configure()?;
|
||||||
@ -106,9 +122,11 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
sqlx::migrate!().run(&pg_pool).await?;
|
sqlx::migrate!().run(&pg_pool).await?;
|
||||||
|
|
||||||
register_commands(config.discord_client_id.to_owned(), config.authorization()).await?;
|
register_commands(config.discord_client_id.to_owned(), config.authorization()).await?;
|
||||||
|
let state = AppState { config, pg_pool };
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/api/discord/interactions/", post(post_interaction))
|
.route("/api/discord/interactions/", post(post_interaction))
|
||||||
.with_state(AppState { config, pg_pool })
|
.with_state(state)
|
||||||
.layer(TraceLayer::new_for_http());
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
||||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||||
@ -155,43 +173,10 @@ async fn shutdown_signal() {
|
|||||||
|
|
||||||
type InteractionResult = Result<(StatusCode, Json<InteractionResponse>), (StatusCode, String)>;
|
type InteractionResult = Result<(StatusCode, Json<InteractionResponse>), (StatusCode, String)>;
|
||||||
|
|
||||||
fn validate_request(headers: HeaderMap, body: String, pub_key: VerifyingKey) -> Result<Interaction, (StatusCode, String)> {
|
|
||||||
let Ok(interaction): Result<Interaction, _> = serde_json::from_str(&body) else {
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "request contained invalid json".to_string()))
|
|
||||||
};
|
|
||||||
let Some(sig) = headers.get("x-signature-ed25519") else {
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "request did not include signature header".to_string()))
|
|
||||||
};
|
|
||||||
let Ok(sig) = hex::decode(sig) else {
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "requrest signature is invalid hex".to_string()))
|
|
||||||
};
|
|
||||||
let Ok(sig) = Signature::from_slice(&sig) else {
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "request signature is malformed".to_string()))
|
|
||||||
};
|
|
||||||
let Some(signed_buf) = headers.get("x-signature-timestamp") else {
|
|
||||||
return Err((StatusCode::BAD_REQUEST, "requrest did not include signature timestamp header".to_string()))
|
|
||||||
};
|
|
||||||
let mut signed_buf = signed_buf.as_bytes().to_owned();
|
|
||||||
signed_buf.extend(body.as_bytes());
|
|
||||||
|
|
||||||
let Ok(()) = pub_key.verify_strict(&signed_buf, &sig) else {
|
|
||||||
return Err((StatusCode::UNAUTHORIZED, "interaction failed signature verification".to_string()))
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(interaction)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn post_interaction(
|
async fn post_interaction(
|
||||||
headers: HeaderMap,
|
State(pg_pool): State<PgPool>,
|
||||||
State(app_state): State<AppState>,
|
ExtractInteraction(interaction): ExtractInteraction,
|
||||||
body: String,
|
|
||||||
) -> InteractionResult {
|
) -> InteractionResult {
|
||||||
let interaction = match validate_request(headers, body, app_state.config.discord_pub_key) {
|
|
||||||
Ok(interaction) => interaction,
|
|
||||||
Err(error) => return Err(error),
|
|
||||||
};
|
|
||||||
|
|
||||||
match interaction.kind {
|
match interaction.kind {
|
||||||
InteractionType::Ping => {
|
InteractionType::Ping => {
|
||||||
let pong = InteractionResponse {
|
let pong = InteractionResponse {
|
||||||
@ -203,7 +188,7 @@ async fn post_interaction(
|
|||||||
}
|
}
|
||||||
InteractionType::ApplicationCommand => {
|
InteractionType::ApplicationCommand => {
|
||||||
let author_id = interaction.author_id();
|
let author_id = interaction.author_id();
|
||||||
let Some(InteractionData::ApplicationCommand(data)) = interaction.data else {
|
let Some(InteractionData::ApplicationCommand(data)) = interaction.data.clone() else {
|
||||||
return not_found();
|
return not_found();
|
||||||
};
|
};
|
||||||
let command_input_data = CommandInputData::from(*data.clone());
|
let command_input_data = CommandInputData::from(*data.clone());
|
||||||
@ -221,7 +206,7 @@ async fn post_interaction(
|
|||||||
interaction.channel_id,
|
interaction.channel_id,
|
||||||
author_id,
|
author_id,
|
||||||
command_data,
|
command_data,
|
||||||
&app_state.pg_pool,
|
&pg_pool,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
@ -237,7 +222,7 @@ async fn post_interaction(
|
|||||||
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME)));
|
return Err((StatusCode::BAD_REQUEST, format!("{0} requires a user.", GetFactCommand::NAME)));
|
||||||
};
|
};
|
||||||
|
|
||||||
match get_fact(interaction.channel_id, author_id, command_data, &app_state.pg_pool).await
|
match get_fact(interaction.channel_id, author_id, command_data, &pg_pool).await
|
||||||
{
|
{
|
||||||
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
Ok(response) => Ok((StatusCode::OK, Json(response))),
|
||||||
Err(err) => Err(err),
|
Err(err) => Err(err),
|
||||||
@ -259,37 +244,3 @@ fn not_found() -> InteractionResult {
|
|||||||
"requested interaction not found".to_string(),
|
"requested interaction not found".to_string(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn register_commands(discord_client_id: Id<ApplicationMarker>, authorization: String) -> anyhow::Result<()> {
|
|
||||||
discord_client(authorization)?
|
|
||||||
.interaction(discord_client_id)
|
|
||||||
.set_global_commands(&[
|
|
||||||
GetFactCommand::create_command().into(),
|
|
||||||
SetFactCommand::create_command().into(),
|
|
||||||
])
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct ClientCredentialsResponse {
|
|
||||||
access_token: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
fn client_credentials_grant(authorization: String) -> anyhow::Result<ClientCredentialsResponse> {
|
|
||||||
Ok(ureq::post("https://discord.com/api/v10/oauth2/token")
|
|
||||||
.set("Authorization", &format!("Basic {}", authorization))
|
|
||||||
.send_form(&[
|
|
||||||
("grant_type", "client_credentials"),
|
|
||||||
("scope", "applications.commands.update"),
|
|
||||||
])
|
|
||||||
.unwrap()
|
|
||||||
.into_json()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn discord_client(authorization: String) -> anyhow::Result<Client> {
|
|
||||||
let token = client_credentials_grant(authorization)?.access_token;
|
|
||||||
Ok(Client::new(format!("Bearer {token}")))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user