diff --git a/Cargo.lock b/Cargo.lock index 9c07f4f..237e42e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,6 +616,7 @@ dependencies = [ "chrono", "deadpool-postgres", "dotenvy", + "once_cell", "openssl", "postgres-openssl", "rand", diff --git a/Cargo.toml b/Cargo.toml index 2b902e4..5dbc8b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ tokio = { version = "1.35", features = ["macros", "rt-multi-thread"] } rand = "0.8.5" walkdir = "2.4.0" surrealdb = "1.1.1" +once_cell = "1.19.0" [[bin]] diff --git a/resources/config.toml b/resources/config.toml index 665b2fb..869c26a 100644 --- a/resources/config.toml +++ b/resources/config.toml @@ -5,11 +5,11 @@ token = "" prefix = "!" [persistence] -host = "localhost" -port = 5432 +url = "localhost:8000" user = "postgres" password = "postgres" database = "postgres" +namespace = "truc" tls = false tls_insecure = false diff --git a/resources/test_config.toml b/resources/test_config.toml index 659e21e..b46ee16 100644 --- a/resources/test_config.toml +++ b/resources/test_config.toml @@ -5,11 +5,11 @@ token = "" prefix = "&" [persistence] -host = "localhost" -port = 5433 +url = "localhost" user = "GB8eE8vh" password = "1OLlRZo1tnNluvx" database = "1pNkVsX3FgFeiQdga" +namespace = "truc" tls = false tls_insecure = false diff --git a/src/config.rs b/src/config.rs index b3154d8..ec14386 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,11 +5,11 @@ use std::env; use std::fs::read_to_string; use std::path::PathBuf; -const PERSISTENCE_HOST: &str = "PERSISTENCE_HOST"; -const PERSISTENCE_PORT: &str = "PERSISTENCE_PORT"; -const PERSISTENCE_USER: &str = "PERSISTENCE_USER"; -const PERSISTENCE_PWD: &str = "PERSISTENCE_PWD"; -const PERSISTENCE_DB: &str = "PERSISTENCE_DB"; +const DB_URL: &str = "DB_URL"; +const DB_USER: &str = "DB_USER"; +const DB_PASSWORD: &str = "DB_PASSWORD"; +const DB_NAME: &str = "DB_NAME"; +const DB_NAMESPACE: &str = "DB_NAMESPACE"; const PERSISTENCE_TLS: &str = "PERSISTENCE_TLS"; const PERSISTENCE_TLS_INSECURE: &str = "PERSISTENCE_TLS_INSECURE"; const BOT_NAME: &str = "BOT_NAME"; @@ -47,11 +47,11 @@ pub struct ImageConfig { #[derive(Deserialize, Clone)] pub struct PersistenceConfig { - pub host: String, - pub port: Option, + pub url: String, pub user: String, pub password: String, pub database: String, + pub namespace: String, pub tls: Option, pub tls_insecure: Option, } @@ -97,17 +97,11 @@ fn override_config_with_env_vars(config: Config) -> Config { .parse::() .unwrap(), persistence: PersistenceConfig { - host: env::var(PERSISTENCE_HOST).unwrap_or(pers.host), - port: env::var(PERSISTENCE_PORT) - .map(|p| { - p.parse::() - .expect("Cannot parse the received persistence port") - }) - .ok() - .or(pers.port), - user: env::var(PERSISTENCE_USER).unwrap_or(pers.user), - password: env::var(PERSISTENCE_PWD).unwrap_or(pers.password), - database: env::var(PERSISTENCE_DB).unwrap_or(pers.database), + url: env::var(DB_URL).unwrap_or(pers.url), + user: env::var(DB_USER).unwrap_or(pers.user), + password: env::var(DB_PASSWORD).unwrap_or(pers.password), + database: env::var(DB_NAME).unwrap_or(pers.database), + namespace: env::var(DB_NAMESPACE).unwrap_or(pers.namespace), tls: env::var(PERSISTENCE_TLS) .map(|p| { p.parse::() @@ -145,8 +139,7 @@ mod tests { let config = parse_config(d); let pers = config.persistence; - assert_eq!("localhost", pers.host); - assert_eq!(5433, pers.port.unwrap()); + assert_eq!("localhost:8000", pers.url); assert_eq!("GB8eE8vh", pers.user); assert_eq!("1OLlRZo1tnNluvx", pers.password); assert_eq!("1pNkVsX3FgFeiQdga", pers.database); @@ -157,11 +150,10 @@ mod tests { #[test] #[serial(config)] fn should_override_a_parsed_config_with_env_vars() { - env::set_var(PERSISTENCE_HOST, "my_host"); - env::set_var(PERSISTENCE_PORT, "1111"); - env::set_var(PERSISTENCE_USER, "just_me"); - env::set_var(PERSISTENCE_PWD, "what_a_pwd"); - env::set_var(PERSISTENCE_DB, "my_db"); + env::set_var(DB_URL, "my_host"); + env::set_var(DB_USER, "just_me"); + env::set_var(DB_PASSWORD, "what_a_pwd"); + env::set_var(DB_NAME, "my_db"); env::set_var(PERSISTENCE_TLS, "true"); env::set_var(PERSISTENCE_TLS_INSECURE, "true"); @@ -170,8 +162,7 @@ mod tests { let config = parse_config(d); let pers = config.persistence; - assert_eq!("my_host", pers.host); - assert_eq!(1111, pers.port.unwrap()); + assert_eq!("my_host", pers.url); assert_eq!("just_me", pers.user); assert_eq!("what_a_pwd", pers.password); assert_eq!("my_db", pers.database); @@ -179,11 +170,10 @@ mod tests { assert_eq!(Some(true), pers.tls_insecure); // reset env vars - env::remove_var(PERSISTENCE_HOST); - env::remove_var(PERSISTENCE_PORT); - env::remove_var(PERSISTENCE_USER); - env::remove_var(PERSISTENCE_PWD); - env::remove_var(PERSISTENCE_DB); + env::remove_var(DB_URL); + env::remove_var(DB_USER); + env::remove_var(DB_PASSWORD); + env::remove_var(DB_NAME); env::remove_var(PERSISTENCE_TLS); env::remove_var(PERSISTENCE_TLS_INSECURE); } diff --git a/src/db/init.rs b/src/db/init.rs new file mode 100644 index 0000000..466b288 --- /dev/null +++ b/src/db/init.rs @@ -0,0 +1,36 @@ +use crate::config::PersistenceConfig; +use once_cell::sync::Lazy; +use surrealdb::engine::remote::ws::Ws; +use surrealdb::opt::auth::Root; +use surrealdb::{engine::remote::ws::Client, Surreal}; + +pub static DB: Lazy> = Lazy::new(Surreal::init); + +pub async fn init_db(config: PersistenceConfig) -> Result<(), surrealdb::Error> { + match DB.connect::(config.url.as_str()).await { + Ok(_) => {} + Err(e) => return Err(e), + }; + + match DB + .signin(Root { + username: config.user.as_str(), + password: config.password.as_str(), + }) + .await + { + Ok(_) => {} + Err(e) => return Err(e), + } + + match DB + .use_ns(config.namespace.as_str()) + .use_db(config.database.as_str()) + .await + { + Ok(_) => {} + Err(e) => return Err(e), + } + + Ok(()) +} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..9221b36 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,2 @@ +pub mod user_image; +pub mod init; \ No newline at end of file diff --git a/src/db/user_image.rs b/src/db/user_image.rs new file mode 100644 index 0000000..d59930a --- /dev/null +++ b/src/db/user_image.rs @@ -0,0 +1,28 @@ +use std::num::NonZeroU64; + +use serde::{Deserialize, Serialize}; +use surrealdb::engine::remote::ws::Ws; +use surrealdb::opt::auth::Root; +use surrealdb::sql::Thing; +use surrealdb::Surreal; + +#[derive(Debug, Serialize)] +pub struct UserImage<'a> { + server_id: &'a NonZeroU64, + user_id: &'a NonZeroU64, + enable: &'a bool, +} + +impl<'a> UserImage<'a> { + pub fn new( + server_id: &'a NonZeroU64, + user_id: &'a NonZeroU64, + enable: &'a bool, + ) -> Result { + Ok(Self { + server_id, + user_id, + enable, + }) + } +} diff --git a/src/main.rs b/src/main.rs index 6edc347..a4a7a8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ mod bot; mod config; +mod db; mod img; use actix_cors::Cors; @@ -11,6 +12,13 @@ use config::parse_local_config; async fn main() -> std::io::Result<()> { let config = parse_local_config(); let port = config.port; + match db::init::init_db(config.persistence.clone()).await { + Ok(_) => {} + Err(e) => { + println!("Error initializing database: {}", e); + return Ok(()); + } + } start_bot(config.clone()); HttpServer::new(|| { let cors = Cors::default()