From 867de13a733dcd4940e5c63149ec0b6c1ad1df1b Mon Sep 17 00:00:00 2001 From: Sahil Gupte Date: Sat, 21 Dec 2024 19:43:42 -0500 Subject: [PATCH] Autogenerate tauri invokes --- Cargo.lock | 18 ++ Cargo.toml | 2 + build.rs | 14 + src-tauri/Cargo.toml | 2 + src-tauri/extensions/src/lib.rs | 7 +- src-tauri/macros/src/command_macro.rs | 4 + src-tauri/src/db/mod.rs | 4 +- src-tauri/src/providers/extension.rs | 8 +- src-tauri/src/providers/spotify.rs | 8 +- src-tauri/src/scanner/mod.rs | 8 +- src-tauri/tauri-invoke-proc/Cargo.toml | 19 ++ src-tauri/tauri-invoke-proc/build.rs | 11 + src-tauri/tauri-invoke-proc/src/common.rs | 15 + src-tauri/tauri-invoke-proc/src/core.rs | 292 ++++++++++++++++++ src-tauri/tauri-invoke-proc/src/lib.rs | 15 + src-tauri/tauri-invoke-proc/src/ui.rs | 352 ++++++++++++++++++++++ src/app.rs | 10 +- src/components/prefs/components.rs | 15 +- src/components/songlist.rs | 2 +- src/modals/discover_extensions.rs | 11 +- src/pages/albums.rs | 2 +- src/players/librespot.rs | 32 +- src/players/rodio.rs | 11 +- src/store/player_store.rs | 4 +- src/store/provider_store.rs | 22 +- src/utils/context_menu.rs | 2 +- src/utils/db_utils.rs | 171 +++-------- src/utils/invoke.rs | 1 + src/utils/mod.rs | 1 + src/utils/prefs.rs | 99 +----- 30 files changed, 861 insertions(+), 301 deletions(-) create mode 100644 build.rs create mode 100644 src-tauri/tauri-invoke-proc/Cargo.toml create mode 100644 src-tauri/tauri-invoke-proc/build.rs create mode 100644 src-tauri/tauri-invoke-proc/src/common.rs create mode 100644 src-tauri/tauri-invoke-proc/src/core.rs create mode 100644 src-tauri/tauri-invoke-proc/src/lib.rs create mode 100644 src-tauri/tauri-invoke-proc/src/ui.rs create mode 100644 src/utils/invoke.rs diff --git a/Cargo.lock b/Cargo.lock index 3ace567f..4ee39878 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,6 +28,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", + "tauri-invoke-proc", "tokio", "tracing", "tracing-subscriber", @@ -5311,6 +5312,7 @@ dependencies = [ "serde_json", "tauri", "tauri-build", + "tauri-invoke-proc", "tauri-plugin-autostart", "tauri-plugin-deep-link", "tauri-plugin-dialog", @@ -9006,6 +9008,22 @@ dependencies = [ "walkdir", ] +[[package]] +name = "tauri-invoke-proc" +version = "0.1.0" +dependencies = [ + "ctor", + "lazy_static", + "once_cell", + "proc-macro2", + "quote", + "regex", + "serde", + "serde_json", + "syn 2.0.90", + "tracing", +] + [[package]] name = "tauri-macros" version = "2.0.3" diff --git a/Cargo.toml b/Cargo.toml index 9717873c..defd7249 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ indexed_db_futures = "0.5.0" lazy_static = "1.5.0" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } +tauri-invoke-proc = { path = "src-tauri/tauri-invoke-proc" } [workspace] resolver = "2" @@ -64,6 +65,7 @@ members = [ "src-tauri/youtube", "src-tauri/rodio-player", "src/pref_gen", + "src-tauri/tauri-invoke-proc", ] [package.metadata.leptos-i18n] diff --git a/build.rs b/build.rs new file mode 100644 index 00000000..cab7cd69 --- /dev/null +++ b/build.rs @@ -0,0 +1,14 @@ +use std::{env, path::Path}; + +fn main() { + let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is not set"); + let manifest_dir = Path::new(&manifest_dir); + let profile = env::var("PROFILE").expect("PROFILE environment variable is not set"); + + let file_path = manifest_dir.join("target").join(profile).join("build"); + + println!( + "cargo:rustc-env=TAURI_INVOKE_PROC_DIR={}", + file_path.to_string_lossy() + ); +} diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 9de316fb..6ae837ef 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -49,6 +49,8 @@ tracing-subscriber = { features = [ "env-filter", ], default-features = false, version = "0.3.19" } rustls = { version = "0.23.20", features = ["ring"] } +tauri-invoke-proc = { path = "./tauri-invoke-proc" } + [build-dependencies.tauri-build] version = "2.0.3" diff --git a/src-tauri/extensions/src/lib.rs b/src-tauri/extensions/src/lib.rs index 2beec4a4..5668a98e 100644 --- a/src-tauri/extensions/src/lib.rs +++ b/src-tauri/extensions/src/lib.rs @@ -26,10 +26,9 @@ use socket_handler::{ExtensionCommandReceiver, MainCommandSender, SocketHandler} use types::{ errors::{MoosyncError, Result}, extensions::{ - AccountLoginArgs, ContextMenuActionArgs, EmptyResp, ExtensionAccountDetail, - ExtensionContextMenuItem, ExtensionDetail, ExtensionExtraEventArgs, ExtensionManifest, - ExtensionProviderScope, FetchedExtensionManifest, GenericExtensionHostRequest, - PackageNameArgs, ToggleExtArgs, + AccountLoginArgs, ContextMenuActionArgs, ExtensionAccountDetail, ExtensionContextMenuItem, + ExtensionDetail, ExtensionExtraEventArgs, ExtensionManifest, ExtensionProviderScope, + FetchedExtensionManifest, GenericExtensionHostRequest, PackageNameArgs, ToggleExtArgs, }, }; use uuid::Uuid; diff --git a/src-tauri/macros/src/command_macro.rs b/src-tauri/macros/src/command_macro.rs index 4fa169bc..36962ebe 100644 --- a/src-tauri/macros/src/command_macro.rs +++ b/src-tauri/macros/src/command_macro.rs @@ -2,6 +2,7 @@ macro_rules! generate_command { ($method_name:ident, $state:ident, $ret:ty, $($v:ident: $t:ty),*) => { #[tracing::instrument(level = "trace", skip(db))] + #[tauri_invoke_proc::parse_tauri_command] #[tauri::command(async)] pub fn $method_name(db: State<$state>, $($v: $t),*) -> types::errors::Result<$ret> { tracing::debug!("calling {}", stringify!($method_name)); @@ -21,6 +22,7 @@ macro_rules! generate_command_cached { ($method_name:ident, $state:ident, $ret:ty, $($v:ident: $t:ty),*) => { // #[flame] #[tracing::instrument(level = "trace", skip(db, cache))] + #[tauri_invoke_proc::parse_tauri_command] #[tauri::command(async)] pub async fn $method_name(db: State<'_, $state>, cache: State<'_, CacheHolder>, $($v: $t),*) -> types::errors::Result<$ret> { let mut cache_string = String::new(); @@ -62,6 +64,7 @@ macro_rules! generate_command_async { ($method_name:ident, $state:ident, $ret:ty, $($v:ident: $t:ty),*) => { // #[flame] #[tracing::instrument(level = "trace", skip(db))] + #[tauri_invoke_proc::parse_tauri_command] #[tauri::command(async)] pub async fn $method_name(db: State<'_, $state>, $($v: $t),*) -> types::errors::Result<$ret> { tracing::debug!("calling async {}", stringify!($method_name)); @@ -81,6 +84,7 @@ macro_rules! generate_command_async_cached { ($method_name:ident, $state:ident, $ret:ty, $($v:ident: $t:ty),*) => { // #[flame] #[tracing::instrument(level = "trace", skip(db, cache))] + #[tauri_invoke_proc::parse_tauri_command] #[tauri::command(async)] pub async fn $method_name(db: State<'_, $state>, cache: State<'_, CacheHolder>, $($v: $t),*) -> types::errors::Result<$ret> { let mut cache_string = String::new(); diff --git a/src-tauri/src/db/mod.rs b/src-tauri/src/db/mod.rs index 6ff0c10a..4072d528 100644 --- a/src-tauri/src/db/mod.rs +++ b/src-tauri/src/db/mod.rs @@ -1,8 +1,7 @@ -use std::fs; - use database::{cache::CacheHolder, database::Database}; use macros::generate_command; use serde_json::Value; +use std::fs; use tauri::{App, AppHandle, Manager, State}; use tracing::{info, trace}; use types::errors::Result; @@ -16,6 +15,7 @@ use types::{ use crate::window::handler::WindowHandler; #[tracing::instrument(level = "trace", skip(app, db, window_handler))] +#[tauri_invoke_proc::parse_tauri_command] #[tauri::command(async)] pub fn export_playlist( app: AppHandle, diff --git a/src-tauri/src/providers/extension.rs b/src-tauri/src/providers/extension.rs index f2814e2e..b695ba44 100644 --- a/src-tauri/src/providers/extension.rs +++ b/src-tauri/src/providers/extension.rs @@ -325,16 +325,16 @@ impl GenericProvider for ExtensionProvider { async fn get_album_content( &self, - album: QueryableAlbum, - pagination: Pagination, + _album: QueryableAlbum, + _pagination: Pagination, ) -> Result<(Vec, Pagination)> { todo!() } async fn get_artist_content( &self, - artist: QueryableArtist, - pagination: Pagination, + _artist: QueryableArtist, + _pagination: Pagination, ) -> Result<(Vec, Pagination)> { todo!() } diff --git a/src-tauri/src/providers/spotify.rs b/src-tauri/src/providers/spotify.rs index 4bcd4911..89f0173b 100644 --- a/src-tauri/src/providers/spotify.rs +++ b/src-tauri/src/providers/spotify.rs @@ -15,9 +15,9 @@ use regex::Regex; use rspotify::{ clients::{BaseClient, OAuthClient}, model::{ - AlbumId, AlbumType, ArtistId, FullAlbum, FullArtist, FullTrack, Id, PlaylistId, - PlaylistTracksRef, SearchType, SimplifiedAlbum, SimplifiedArtist, SimplifiedPlaylist, - SimplifiedTrack, TrackId, + AlbumId, ArtistId, FullAlbum, FullArtist, FullTrack, Id, PlaylistId, PlaylistTracksRef, + SearchType, SimplifiedAlbum, SimplifiedArtist, SimplifiedPlaylist, SimplifiedTrack, + TrackId, }, AuthCodePkceSpotify, Token, }; @@ -772,7 +772,7 @@ impl GenericProvider for SpotifyProvider { if let Some(api_client) = &self.api_client { if let Some(next_page_token) = &pagination.token { // TODO: Fetch next pages - let tokens = next_page_token.split(";").collect::>(); + let _tokens = next_page_token.split(";").collect::>(); return Ok((vec![], pagination.next_page_wtoken(None))); } diff --git a/src-tauri/src/scanner/mod.rs b/src-tauri/src/scanner/mod.rs index c1780f96..b0cf2c79 100644 --- a/src-tauri/src/scanner/mod.rs +++ b/src-tauri/src/scanner/mod.rs @@ -1,10 +1,6 @@ use std::{ - sync::{ - atomic::AtomicBool, - mpsc::{Receiver, Sender}, - Arc, Mutex, - }, - thread::{self, JoinHandle}, + sync::{atomic::AtomicBool, Arc, Mutex}, + thread::{self}, time::Duration, }; diff --git a/src-tauri/tauri-invoke-proc/Cargo.toml b/src-tauri/tauri-invoke-proc/Cargo.toml new file mode 100644 index 00000000..71174897 --- /dev/null +++ b/src-tauri/tauri-invoke-proc/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "tauri-invoke-proc" +version = "0.1.0" +edition = "2021" + +[dependencies] +ctor = "0.2.9" +lazy_static = "1.5.0" +once_cell = "1.20.2" +proc-macro2 = "1.0.92" +quote = "1.0.37" +regex = "1.11.1" +serde = "1.0.216" +serde_json = "1.0.133" +syn = "2.0.90" +tracing = "0.1.41" + +[lib] +proc-macro = true diff --git a/src-tauri/tauri-invoke-proc/build.rs b/src-tauri/tauri-invoke-proc/build.rs new file mode 100644 index 00000000..af721ef4 --- /dev/null +++ b/src-tauri/tauri-invoke-proc/build.rs @@ -0,0 +1,11 @@ +use std::env; +use std::path::Path; + +fn main() { + let out_dir = env::var("OUT_DIR").expect("OUT_DIR environment variable is not set"); + let output_file = Path::new(&out_dir) + .join("../../") + .join("function_details.json"); + + println!("cargo:rerun-if-changed={}", output_file.to_string_lossy()); +} diff --git a/src-tauri/tauri-invoke-proc/src/common.rs b/src-tauri/tauri-invoke-proc/src/common.rs new file mode 100644 index 00000000..2ceea6a0 --- /dev/null +++ b/src-tauri/tauri-invoke-proc/src/common.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct FnDetails { + pub name: String, + pub args: Vec, + pub ret: Option, +} + +#[derive(Serialize, Deserialize, Debug)] + +pub struct FnArgs { + pub name: String, + pub arg_type: String, +} diff --git a/src-tauri/tauri-invoke-proc/src/core.rs b/src-tauri/tauri-invoke-proc/src/core.rs new file mode 100644 index 00000000..52d487c5 --- /dev/null +++ b/src-tauri/tauri-invoke-proc/src/core.rs @@ -0,0 +1,292 @@ +/// Completely written using ChatGPT +/// Except the use statement parsing logic +use once_cell::sync::Lazy; +use proc_macro::TokenStream; +use quote::quote; +use regex::Regex; +use std::collections::{HashMap, HashSet}; +use std::path::Path; +use std::sync::Mutex; +use std::{env, fs}; +use syn::{parse_macro_input, FnArg, ItemFn, ReturnType, Type}; + +use crate::common::{FnArgs, FnDetails}; + +lazy_static::lazy_static! { + static ref FUNCTION_DETAILS: Mutex> = Mutex::new(Vec::new()); + static ref TYPE_CACHE: Lazy>> = Lazy::new(|| { + let crate_path = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is not set"); + let src_path = Path::new(&crate_path).join("src"); + parse_use_statements_from_source(&src_path) + }); +} + +pub fn generate_tauri_invoke_wrapper(_attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse the input TokenStream as a function + let input = parse_macro_input!(item as ItemFn); + + // Extract the function name + let fn_name = input.sig.ident.to_string(); + + // Extract the arguments and resolve their types + let args = input + .sig + .inputs + .iter() + .filter_map(|arg| match arg { + FnArg::Typed(pat_type) => { + if let syn::Pat::Ident(ident) = &*pat_type.pat { + let raw_type = resolve_type_path(&pat_type.ty); + Some(FnArgs { + name: ident.ident.to_string(), + arg_type: raw_type, + }) + } else { + None + } + } + _ => None, + }) + .collect::>(); + + let ret = match &input.sig.output { + ReturnType::Type(_, ty) => Some(resolve_type_path(ty)), + ReturnType::Default => None, // No return type (i.e., "-> ()") + }; + + // Add function details to the global variable + let details = FnDetails { + name: fn_name, + args, + ret, + }; + FUNCTION_DETAILS + .lock() + .expect("Failed to acquire lock on FUNCTION_DETAILS") + .push(details); + + // Return the original function unchanged + TokenStream::from(quote! { + #input + }) +} + +/// Resolves a raw type using the cached `use` map. If resolution fails, returns the raw type as-is. +fn resolve_with_cache(raw_type: &str) -> String { + let import = raw_type.split("::").last(); + if let Some(import) = import { + let imports = TYPE_CACHE.get(import); + if let Some(imports) = imports { + return imports + .iter() + .find(|i| i.contains(raw_type)) + .cloned() + .unwrap_or(raw_type.to_string()); + } + } + raw_type.to_string() +} + +/// Reads all `.rs` files in the `src` directory recursively and extracts `use` statements +fn parse_use_statements_from_source(src_path: &Path) -> HashMap> { + let mut use_map = HashMap::new(); + + fn recursive_scan(dir: &Path, use_map: &mut HashMap>) { + for entry in fs::read_dir(dir).expect("Failed to read directory") { + let entry = entry.expect("Failed to read directory entry"); + let path = entry.path(); + + if path.is_dir() { + recursive_scan(&path, use_map); + } else if path.is_file() && path.extension().map_or(false, |ext| ext == "rs") { + let source_code = fs::read_to_string(&path).expect("Failed to read source file"); + parse_use_statements(&source_code, use_map); + } + } + } + + recursive_scan(src_path, &mut use_map); + use_map +} + +/// Parses `use` statements from the given source code and populates the `use_map` +fn parse_use_statements(source: &str, use_map: &mut HashMap>) { + // Regex to match `use` statements, including those spanning multiple lines + let use_regex = Regex::new(r"use\s+([^;]+);").expect("Invalid regex"); + + // Find all matches for `use` statements + for caps in use_regex.captures_iter(source) { + let full_use_statement = &caps[1]; + process_use_statement(full_use_statement.trim(), use_map); + } +} + +/// Processes a single `use` statement recursively, handling nested structures +fn process_use_statement(statement: &str, use_map: &mut HashMap>) { + let mut queue = vec![]; + let mut ret = vec![]; + let mut last_index = 0; + for (i, c) in statement.char_indices() { + match c { + '{' => { + let end = statement[last_index..i].trim(); + if !end.is_empty() { + queue.push(end.to_string()); + } + last_index = i + 1; + } + '}' => { + let end = statement[last_index..i].trim(); + if !end.is_empty() { + ret.push(format!("{}{}", queue.join(""), end)); + } + queue.pop(); + last_index = i + 1; + } + ',' => { + let end = statement[last_index..i].trim(); + if !end.is_empty() { + ret.push(format!("{}{}", queue.join(""), end)); + } + last_index = i + 1; + } + _ => {} + } + } + + let rem = statement[last_index..statement.len()].trim(); + if !rem.is_empty() { + ret.push(rem.to_string()); + } + + for item in ret { + let import = item.split("::").last(); + if let Some(import) = import { + if use_map.contains_key(import) { + let existing = use_map.get_mut(import); + if let Some(existing) = existing { + existing.insert(item); + } + } else { + let mut hash_set = HashSet::new(); + hash_set.insert(item.clone()); + use_map.insert(import.to_string(), hash_set); + } + } + } +} + +/// Extracts the raw type as a string from a `syn::Type` +fn resolve_type_path(ty: &Box) -> String { + match &**ty { + // Handle simple type paths (e.g., `A`, `Vec`) + Type::Path(type_path) => { + // Base type (e.g., `Vec`, `HashMap`) + let resolved_segments = type_path + .path + .segments + .iter() + .map(|s| s.ident.to_string()) + .collect::>() + .join("::"); + + // Combine the resolved segments into a fully qualified path + let base_type = resolve_with_cache(&resolved_segments); + + // Check for generics (e.g., ``) + if let Some(generic_args) = type_path + .path + .segments + .last() // Generics are associated with the last segment + .and_then(|segment| match &segment.arguments { + syn::PathArguments::AngleBracketed(generics) => Some(&generics.args), + _ => None, + }) + { + // Process each generic argument recursively + let generic_types: Vec = generic_args + .iter() + .filter_map(|arg| match arg { + syn::GenericArgument::Type(inner_ty) => { + Some(resolve_type_path(&Box::new(inner_ty.clone()))) + } + _ => None, // Skip non-type generics (e.g., lifetimes) + }) + .collect(); + + // Return the base type with resolved generics + format!("{}<{}>", base_type, generic_types.join(", ")) + } else { + // No generics, return the resolved base type + base_type + } + } + Type::Tuple(type_tuple) => { + // Recursively resolve each element in the tuple + let resolved_elements: Vec = type_tuple + .elems + .iter() + .map(|elem| resolve_type_path(&Box::new(elem.clone()))) + .collect(); + + // Return the resolved tuple as `(A, B, C, ...)` + format!("({})", resolved_elements.join(", ")) + } + Type::Reference(type_ref) => { + // Skip `&` and `mut`, resolve the inner type + resolve_type_path(&type_ref.elem) + } + _ => quote::quote!(#ty).to_string(), + } +} + +// Ensure the global variable is written to the file at program exit +#[ctor::dtor] +fn write_function_details_to_file() { + // Retrieve the crate name from the environment variable + let crate_name = + std::env::var("CARGO_PKG_NAME").unwrap_or_else(|_| "unknown_crate".to_string()); + + // Lock the global FUNCTION_DETAILS and serialize its content + let data = FUNCTION_DETAILS + .lock() + .expect("Failed to acquire lock on FUNCTION_DETAILS"); + + // Convert the current FUNCTION_DETAILS to a JSON object under the crate name + let current_data = serde_json::json!({ crate_name.clone(): &*data }); + if current_data[crate_name.clone()] + .as_array() + .cloned() + .unwrap_or_default() + .is_empty() + { + return; + } + + // Path to the output JSON file + let out_dir = env::var("OUT_DIR").expect("OUT_DIR is not set"); + + // Construct the file path under `target` + let file_path = Path::new(&out_dir) + .join("../../") + .join("function_details.json"); + + // Read existing JSON file content, if any + let mut existing_data = match std::fs::read_to_string(file_path.clone()) { + Ok(content) => serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({})), + Err(_) => serde_json::json!({}), // If file doesn't exist, start with an empty object + }; + + // Merge the current data into the existing data + if let serde_json::Value::Object(ref mut map) = existing_data { + map.insert(crate_name.clone(), current_data[crate_name].clone()); + } else { + existing_data = current_data; + } + + // Serialize the updated data and write it back to the file + let json_output = + serde_json::to_string_pretty(&existing_data).expect("Failed to serialize JSON data"); + + std::fs::write(file_path, json_output).expect("Failed to write JSON data to file"); +} diff --git a/src-tauri/tauri-invoke-proc/src/lib.rs b/src-tauri/tauri-invoke-proc/src/lib.rs new file mode 100644 index 00000000..9e4acb6c --- /dev/null +++ b/src-tauri/tauri-invoke-proc/src/lib.rs @@ -0,0 +1,15 @@ +use proc_macro::TokenStream; + +mod common; +mod core; +mod ui; + +#[proc_macro_attribute] +pub fn parse_tauri_command(attr: TokenStream, item: TokenStream) -> TokenStream { + core::generate_tauri_invoke_wrapper(attr, item) +} + +#[proc_macro] +pub fn generate_tauri_invoke(item: TokenStream) -> TokenStream { + ui::generate_tauri_invoke_wrapper(item) +} diff --git a/src-tauri/tauri-invoke-proc/src/ui.rs b/src-tauri/tauri-invoke-proc/src/ui.rs new file mode 100644 index 00000000..30ec0acd --- /dev/null +++ b/src-tauri/tauri-invoke-proc/src/ui.rs @@ -0,0 +1,352 @@ +use std::{collections::HashMap, env, fs, path::Path}; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, ExprArray, Ident}; + +use crate::common::FnDetails; + +pub fn generate_tauri_invoke_wrapper(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ExprArray); + let valid_crates = input + .elems + .iter() + .filter_map(|e| { + if let syn::Expr::Lit(lit) = e { + if let syn::Lit::Str(lit) = &lit.lit { + return Some(lit.value()); + } + } + None + }) + .collect(); + + let out_dir = env::var("TAURI_INVOKE_PROC_DIR") + .expect("TAURI_INVOKE_PROC_DIR environment variable is not set"); + + let file_path = Path::new(&out_dir).join("function_details.json"); + + let json_content = fs::read_to_string(file_path).expect("Failed to read function_details.json"); + let json: HashMap> = serde_json::from_str(&json_content) + .expect("Failed to parse JSON from function_details.json"); + + let generated_fns = parse_crate(json, valid_crates); + + let output = quote! { + #(#generated_fns)* + }; + + output.into() +} + +fn parse_crate( + json: HashMap>, + valid_crates: Vec, +) -> Vec { + let mut ret = vec![]; + for funcs in json.values() { + for func in funcs { + ret.push(parse_fn(func, &valid_crates)); + } + } + + ret +} + +#[derive(Clone)] +struct FnNameArg { + name: proc_macro2::Ident, + typ: syn::Type, +} + +impl ToTokens for FnNameArg { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let ident = &self.name; + let typ = &self.typ; // Destructure the tuple + let generated = quote! { + #ident: #typ + }; + tokens.extend(generated); + } +} + +fn is_external_allowed(type_name: &str) -> bool { + matches!(type_name, "serde_json" | "std") +} + +fn is_primitive_type(type_name: &str) -> bool { + matches!( + type_name, + "bool" + | "char" + | "str" + | "String" + | "i8" + | "i16" + | "i32" + | "i64" + | "i128" + | "isize" + | "u8" + | "u16" + | "u32" + | "u64" + | "u128" + | "usize" + | "f32" + | "f64" + | "Option" + | "Vec" + ) +} + +fn is_valid_type(arg_type: &syn::Type, valid_crates: &Vec) -> bool { + match arg_type { + // Check simple type paths (e.g., String, serde_json::Value) + syn::Type::Path(type_path) => { + if let Some(last_segment) = type_path.path.segments.last() { + // Check for generics in the type (e.g., Vec) + if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments { + // Validate all generic arguments + let is_valid = args.args.iter().all(|arg| match arg { + syn::GenericArgument::Type(inner_type) => { + is_valid_type(inner_type, valid_crates) + } + _ => true, // Ignore lifetimes or other non-type arguments + }); + + if !is_valid { + return false; + } + } + } + + // Extract the leading segment (crate or module name) + if let Some(first_segment) = type_path.path.segments.first() { + let type_name = first_segment.ident.to_string(); + // Allow all primitive types implicitly + if is_primitive_type(&type_name) || is_external_allowed(&type_name) { + return true; + } + + // Check if the type belongs to a valid crate or is a primitive + if valid_crates + .iter() + .any(|valid| type_name.starts_with(valid)) + { + return true; + } + } + false + } + // Allow primitives explicitly + syn::Type::Reference(type_ref) => { + is_valid_type(&type_ref.elem, valid_crates) // Validate referenced type + } + syn::Type::Tuple(type_tuple) => { + // Validate all elements of a tuple (e.g., (valid_type, valid_type)) + type_tuple + .elems + .iter() + .all(|elem| is_valid_type(elem, valid_crates)) + } + _ => false, // Disallow all other types + } +} + +fn replace_serde_json_with_jsvalue( + ty: &syn::Type, + has_generics: bool, + generic_count: u64, +) -> (u64, syn::Type) { + let mut new_count = generic_count; + match ty { + syn::Type::Path(type_path) => { + let mut type_path = type_path.clone(); + + // Check if the type is `serde_json::Value` + if let Some(first_segment) = type_path.path.segments.first_mut() { + if first_segment.ident == "serde_json" { + if let syn::PathArguments::None = first_segment.arguments { + if type_path.path.segments.len() == 2 + && type_path.path.segments[1].ident == "Value" + { + if has_generics { + return ( + new_count + 1, + syn::parse_str(format!("T{}", new_count + 1).as_str()) + .expect("Failed to parse replacement generic type"), + ); + } + // Replace `serde_json::Value` with `serde_wasm_bindgen::JsValue` + return ( + new_count + 1, + syn::parse_str("wasm_bindgen::JsValue") + .expect("Failed to parse replacement type"), + ); + } + } + } + } + + // Recursively process generic arguments, if any + for segment in &mut type_path.path.segments { + if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments { + for generic_arg in &mut args.args { + if let syn::GenericArgument::Type(inner_ty) = generic_arg { + let (generic_count, new_val) = + replace_serde_json_with_jsvalue(inner_ty, has_generics, new_count); + *inner_ty = new_val; + new_count = generic_count; + } + } + } + } + + (new_count, syn::Type::Path(type_path)) + } + syn::Type::Reference(type_ref) => { + replace_serde_json_with_jsvalue(&type_ref.elem, has_generics, new_count) + // Validate referenced type + } + syn::Type::Tuple(type_tuple) => { + // Validate all elements of a tuple (e.g., (valid_type, valid_type)) + let mut type_tuple = type_tuple.clone(); + type_tuple.elems = type_tuple + .elems + .into_iter() + .map(|elem| { + let (generic_count, typ) = + replace_serde_json_with_jsvalue(&elem, has_generics, new_count); + new_count = generic_count; + typ + }) + .collect(); + (new_count, syn::Type::Tuple(type_tuple)) + } + _ => (new_count, ty.clone()), + } +} + +fn parse_fn(dets: &FnDetails, valid_crates: &Vec) -> proc_macro2::TokenStream { + let invoke_name_lit = dets.name.clone(); + let func_name_ident = syn::Ident::new(&dets.name, Span::call_site()); + + let mut generics_needed = 0; + + let args = dets + .args + .iter() + .filter_map(|arg| { + if !arg.arg_type.starts_with("tauri::") { + let arg_name = syn::Ident::new(&arg.name, proc_macro2::Span::call_site()); + + let mut arg_type = syn::parse_str::(&arg.arg_type).unwrap(); + if !is_valid_type(&arg_type, valid_crates) { + eprintln!( + "Found not allowed type {}. Parsing it as serde_json::Value", + arg.arg_type + ); + arg_type = syn::parse_str::("wasm_bindgen::JsValue").unwrap(); + } + + let (generic_count, arg_type) = replace_serde_json_with_jsvalue(&arg_type, true, 0); + generics_needed = generic_count; + return Some(FnNameArg { + name: arg_name, + typ: arg_type, + }); + } + None + }) + .collect::>(); + + let mut is_ret_value = false; + + // Parse return type, if present + let ret_type = if let Some(ret) = &dets.ret { + // Parse the return type string into a `syn::Type` + let mut parsed_ret_type = + syn::parse_str::(ret).expect("Failed to parse return type"); + if !is_valid_type(&parsed_ret_type, valid_crates) { + eprintln!( + "Found not allowed type {}. Parsing it as wasm_bindgen::JsValue", + ret + ); + is_ret_value = true; + parsed_ret_type = + syn::parse_str::("types::errors::Result") + .unwrap(); + } + + let (is_changed, parsed_ret_type) = + replace_serde_json_with_jsvalue(&parsed_ret_type, false, 0); + if is_changed > 0 { + is_ret_value = true; + } + quote! { -> #parsed_ret_type } + } else { + // No return type (i.e., `()`) + quote! {} + }; + + let binding = args.clone(); + let struct_fields = binding.into_iter().map(|arg| { + let arg_name = arg.name; + let arg_typ = arg.typ; + quote! { + pub #arg_name: #arg_typ + } + }); + + let binding = args.clone(); + let struct_values = binding.into_iter().map(|arg| { + let arg_name = arg.name; + quote! { + #arg_name + } + }); + + let ret_val = if is_ret_value { + quote! { + Ok(res) + } + } else { + quote! { + Ok(serde_wasm_bindgen::from_value(res)?) + } + }; + + let (generic_params, where_clause) = if generics_needed > 0 { + let generics = (1..=generics_needed) + .map(|i| Ident::new(&format!("T{}", i), Span::call_site())) + .collect::>(); + ( + quote! { <#(#generics),*> }, + quote! { + where #(#generics: serde::Serialize + 'static),* + }, + ) + } else { + (quote! {}, quote! {}) + }; + + quote! { + pub async fn #func_name_ident #generic_params(#(#args),*) #ret_type #where_clause { + #[derive(serde::Serialize)] + struct Args #generic_params #where_clause { + #(#struct_fields),* + } + + let args = serde_wasm_bindgen::to_value(&Args { + #(#struct_values),* + }).unwrap(); + + let res = crate::utils::common::invoke(#invoke_name_lit, args) + .await?; + + #ret_val + } + } +} diff --git a/src/app.rs b/src/app.rs index 632c7b4d..1a3bc70f 100644 --- a/src/app.rs +++ b/src/app.rs @@ -4,13 +4,13 @@ use crate::{ components::{ better_animated_outlet::AnimatedOutletSimultaneous, prefs::static_components::SettingRoutes, }, - console_info, pages::explore::Explore, players::librespot::LibrespotPlayer, store::ui_store::UiStore, utils::{ common::{emit, invoke, listen_event}, - prefs::{load_selective_async, watch_preferences}, + invoke::load_selective, + prefs::watch_preferences, }, }; use leptos::{ @@ -165,10 +165,8 @@ pub fn App() -> impl IntoView { provide_i18n_context::(); spawn_local(async move { - let id = load_selective_async("themes.active_theme".into()) - .await - .unwrap(); - handle_theme(id); + let id = load_selective("themes.active_theme".into()).await.unwrap(); + handle_theme(serde_wasm_bindgen::from_value(id).unwrap()); }); let ui_requests_unlisten = listen_event("ui-requests", move |data| { diff --git a/src/components/prefs/components.rs b/src/components/prefs/components.rs index f18e880a..438e1fce 100644 --- a/src/components/prefs/components.rs +++ b/src/components/prefs/components.rs @@ -15,7 +15,6 @@ use types::{ themes::ThemeDetails, window::DialogFilter, }; -use wasm_bindgen::JsValue; use wasm_bindgen_futures::spawn_local; use crate::{ @@ -28,6 +27,7 @@ use crate::{ utils::{ common::invoke, context_menu::ThemesContextMenu, + invoke::{get_installed_extensions, load_all_themes}, prefs::{ load_selective, open_file_browser, open_file_browser_single, save_selective, save_selective_number, @@ -344,10 +344,8 @@ pub fn ThemesPref( let all_themes: RwSignal> = create_rw_signal(Default::default()); let load_themes = move || { spawn_local(async move { - let themes = invoke("load_all_themes", JsValue::undefined()) - .await - .unwrap(); - all_themes.set(serde_wasm_bindgen::from_value(themes).unwrap()); + let themes = load_all_themes().await.unwrap(); + all_themes.set(themes); }) }; load_themes(); @@ -442,12 +440,7 @@ pub fn ExtensionPref(#[prop()] title: String, #[prop()] tooltip: String) -> impl let extensions = create_rw_signal::>(Default::default()); let fetch_extensions = move || { spawn_local(async move { - let res = invoke("get_installed_extensions", JsValue::undefined()) - .await - .unwrap(); - tracing::debug!("Got res {:?}", res); - let res = serde_wasm_bindgen::from_value::>>(res) - .unwrap(); + let res = get_installed_extensions().await.unwrap(); tracing::debug!("got extensions {:?}", res); extensions.set(res.values().flatten().cloned().collect()); }) diff --git a/src/components/songlist.rs b/src/components/songlist.rs index 03b5168e..1c665893 100644 --- a/src/components/songlist.rs +++ b/src/components/songlist.rs @@ -6,7 +6,7 @@ use leptos::{ ev::{keydown, keyup}, event_target_value, expect_context, html::Input, - use_context, view, window_event_listener, HtmlElement, IntoView, ReadSignal, RwSignal, Show, + use_context, view, window_event_listener, HtmlElement, IntoView, RwSignal, Show, SignalGet, SignalSet, SignalUpdate, }; use leptos_context_menu::ContextMenu; diff --git a/src/modals/discover_extensions.rs b/src/modals/discover_extensions.rs index 7039cf53..d0fcf106 100644 --- a/src/modals/discover_extensions.rs +++ b/src/modals/discover_extensions.rs @@ -1,18 +1,17 @@ use leptos::{component, create_rw_signal, spawn_local, view, For, IntoView, SignalGet, SignalSet}; use types::extensions::FetchedExtensionManifest; -use wasm_bindgen::JsValue; -use crate::{modals::common::GenericModal, utils::common::invoke}; +use crate::{ + modals::common::GenericModal, + utils::{common::invoke, invoke::get_extension_manifest}, +}; #[tracing::instrument(level = "trace", skip())] #[component] pub fn DiscoverExtensionsModal() -> impl IntoView { let extensions = create_rw_signal(vec![]); spawn_local(async move { - let res = invoke("get_extension_manifest", JsValue::undefined()) - .await - .unwrap(); - let res: Vec = serde_wasm_bindgen::from_value(res).unwrap(); + let res = get_extension_manifest().await.unwrap(); extensions.set(res); }); diff --git a/src/pages/albums.rs b/src/pages/albums.rs index a4ce9dfb..a7e46ccc 100644 --- a/src/pages/albums.rs +++ b/src/pages/albums.rs @@ -11,7 +11,7 @@ use crate::utils::common::fetch_infinite; use crate::utils::songs::get_songs_from_indices; use leptos::{ component, create_effect, create_memo, create_rw_signal, create_write_slice, expect_context, - view, IntoView, RwSignal, SignalGet, SignalUpdate, SignalUpdateUntracked, SignalWith, + view, IntoView, RwSignal, SignalGet, SignalUpdate, SignalWith, }; use leptos_router::use_query_map; use rand::seq::SliceRandom; diff --git a/src/players/librespot.rs b/src/players/librespot.rs index 1eaebaea..99f4698e 100644 --- a/src/players/librespot.rs +++ b/src/players/librespot.rs @@ -8,7 +8,7 @@ use wasm_bindgen_futures::spawn_local; use crate::utils::{ common::{invoke, listen_event}, - prefs::load_selective_async, + invoke::{is_initialized, librespot_pause, librespot_play, load_selective}, }; use super::generic::GenericPlayer; @@ -107,14 +107,13 @@ impl LibrespotPlayer { fn initialize_librespot() { if *ENABLED.lock().unwrap() { spawn_local(async move { - let res = invoke("is_initialized", JsValue::undefined()).await; + let res = is_initialized().await; tracing::debug!("Librespot initialized: {:?}", res); - if let Ok(res) = res { - if let Some(initialized) = res.as_bool() { - *INITIALIZED.lock().unwrap() = initialized; - return; - } + if let Ok(initialized) = res { + *INITIALIZED.lock().unwrap() = initialized; + return; } + *INITIALIZED.lock().unwrap() = false; }) } @@ -173,12 +172,19 @@ impl GenericPlayer for LibrespotPlayer { #[tracing::instrument(level = "trace", skip(self))] fn initialize(&self, _: leptos::NodeRef) { spawn_local(async move { - let enabled: Vec = load_selective_async("spotify.enable".into()) - .await - .unwrap_or(vec![CheckboxPreference { + let data = load_selective("spotify.enable".into()).await; + + let enabled: Vec = if let Ok(data) = data { + serde_wasm_bindgen::from_value(data).unwrap_or(vec![CheckboxPreference { + key: "enable".into(), + enabled: true, + }]) + } else { + vec![CheckboxPreference { key: "enable".into(), enabled: true, - }]); + }] + }; for pref in enabled { if pref.key == "enable" { LibrespotPlayer::set_enabled(pref.enabled) @@ -224,7 +230,7 @@ impl GenericPlayer for LibrespotPlayer { #[tracing::instrument(level = "trace", skip(self))] fn play(&self) -> types::errors::Result<()> { spawn_local(async move { - let res = invoke("librespot_play", JsValue::undefined()).await; + let res = librespot_play().await; if res.is_err() { tracing::error!("Error playing {:?}", res.unwrap_err()); @@ -236,7 +242,7 @@ impl GenericPlayer for LibrespotPlayer { #[tracing::instrument(level = "trace", skip(self))] fn pause(&self) -> types::errors::Result<()> { spawn_local(async move { - let res = invoke("librespot_pause", JsValue::undefined()).await; + let res = librespot_pause().await; if res.is_err() { tracing::error!("Error pausing {:?}", res.unwrap_err()); diff --git a/src/players/rodio.rs b/src/players/rodio.rs index 1f76290c..aad68ab9 100644 --- a/src/players/rodio.rs +++ b/src/players/rodio.rs @@ -5,7 +5,10 @@ use serde::Serialize; use types::{songs::SongType, ui::player_details::PlayerEvents}; use wasm_bindgen::JsValue; -use crate::utils::common::{convert_file_src, invoke, listen_event}; +use crate::utils::{ + common::{convert_file_src, invoke, listen_event}, + invoke::{rodio_pause, rodio_play, rodio_stop}, +}; use super::generic::GenericPlayer; @@ -67,7 +70,7 @@ impl GenericPlayer for RodioPlayer { } spawn_local(async move { - let res = invoke("rodio_stop", JsValue::undefined()).await; + let res = rodio_stop().await; if res.is_err() { tracing::error!("Error stopping {:?}", res.unwrap_err()); @@ -89,7 +92,7 @@ impl GenericPlayer for RodioPlayer { #[tracing::instrument(level = "trace", skip(self))] fn play(&self) -> types::errors::Result<()> { spawn_local(async move { - let res = invoke("rodio_play", JsValue::undefined()).await; + let res = rodio_play().await; if res.is_err() { tracing::error!("Error playing {:?}", res.unwrap_err()); @@ -101,7 +104,7 @@ impl GenericPlayer for RodioPlayer { #[tracing::instrument(level = "trace", skip(self))] fn pause(&self) -> types::errors::Result<()> { spawn_local(async move { - let res = invoke("rodio_pause", JsValue::undefined()).await; + let res = rodio_pause().await; if res.is_err() { tracing::error!("Error playing {:?}", res.unwrap_err()); diff --git a/src/store/player_store.rs b/src/store/player_store.rs index 9df48b80..9a094a88 100644 --- a/src/store/player_store.rs +++ b/src/store/player_store.rs @@ -5,13 +5,12 @@ use indexed_db_futures::{ IdbDatabase, IdbVersionChangeEvent, }; use leptos::{ - create_effect, create_read_slice, create_rw_signal, RwSignal, SignalGet, SignalSet, + create_effect, create_rw_signal, RwSignal, SignalGet, SignalSet, SignalUpdate, }; use rand::seq::SliceRandom; use serde::Serialize; use std::{cmp::min, collections::HashMap, rc::Rc}; -use tracing::debug; use types::{ extensions::ExtensionExtraEvent, preferences::CheckboxPreference, @@ -22,7 +21,6 @@ use wasm_bindgen::JsValue; use wasm_bindgen_futures::spawn_local; use crate::utils::{ - common::info, db_utils::{read_from_indexed_db, write_to_indexed_db}, extensions::send_extension_event, mpris::{set_metadata, set_playback_state, set_position}, diff --git a/src/store/provider_store.rs b/src/store/provider_store.rs index b44cd445..bd6a9319 100644 --- a/src/store/provider_store.rs +++ b/src/store/provider_store.rs @@ -15,6 +15,9 @@ use wasm_bindgen::JsValue; use crate::players::librespot::LibrespotPlayer; use crate::store::modal_store::{ModalStore, Modals}; use crate::utils::common::{invoke, listen_event}; +use crate::utils::invoke::{ + get_all_status, get_provider_key_by_id, get_provider_keys, initialize_all_providers, +}; #[derive(Debug, Default)] pub struct ProviderStore { @@ -78,12 +81,12 @@ impl ProviderStore { let fetch_provider_keys = move || { spawn_local(async move { - let provider_keys = invoke("get_provider_keys", JsValue::undefined()).await; + let provider_keys = get_provider_keys().await; if provider_keys.is_err() { tracing::debug!("Failed to get provider keys"); return; } - store.keys.set(from_value(provider_keys.unwrap()).unwrap()); + store.keys.set(provider_keys.unwrap()); tracing::debug!("Updated provider keys {:?}", store.keys.get()); }); }; @@ -121,16 +124,12 @@ impl ProviderStore { #[cfg(not(feature = "mock"))] { - let res = invoke("initialize_all_providers", JsValue::undefined()).await; + let res = initialize_all_providers().await; if res.is_err() { tracing::error!("Failed to initialize providers"); } - let status = invoke("get_all_status", JsValue::undefined()) - .await - .unwrap(); + let statuses = get_all_status().await.unwrap(); - let statuses: HashMap = - serde_wasm_bindgen::from_value(status).unwrap(); store.statuses.set(statuses.values().cloned().collect()); store.is_initialized.set(true); } @@ -146,12 +145,7 @@ impl ProviderStore { #[tracing::instrument(level = "trace", skip(self, id))] pub async fn get_provider_key_by_id(&self, id: String) -> Result { - #[derive(Debug, Serialize)] - struct Args { - id: String, - } - let res = invoke("get_provider_key_by_id", to_value(&Args { id }).unwrap()).await?; - Ok(from_value(res)?) + get_provider_key_by_id(id).await } #[tracing::instrument(level = "trace", skip(self))] diff --git a/src/utils/context_menu.rs b/src/utils/context_menu.rs index aadc7625..d37ea7f9 100644 --- a/src/utils/context_menu.rs +++ b/src/utils/context_menu.rs @@ -1,4 +1,4 @@ -use leptos::{use_context, ReadSignal, RwSignal, SignalGet, SignalUpdate}; +use leptos::{use_context, RwSignal, SignalGet, SignalUpdate}; use leptos_context_menu::{ContextMenuData, ContextMenuItemInner, ContextMenuItems}; use leptos_router::{use_navigate, NavigateOptions}; use serde::Serialize; diff --git a/src/utils/db_utils.rs b/src/utils/db_utils.rs index e3054070..db45629c 100644 --- a/src/utils/db_utils.rs +++ b/src/utils/db_utils.rs @@ -5,7 +5,7 @@ use indexed_db_futures::IdbDatabase; use indexed_db_futures::IdbQuerySource; use leptos::{spawn_local, SignalSet, SignalUpdate}; use serde::Serialize; -use serde_wasm_bindgen::{from_value, to_value}; +use serde_wasm_bindgen::from_value; use types::entities::QueryableAlbum; use types::entities::QueryableArtist; use types::entities::QueryableGenre; @@ -17,36 +17,14 @@ use wasm_bindgen::JsValue; use web_sys::DomException; use web_sys::IdbTransactionMode; - -use super::common::invoke; - -#[derive(Serialize)] -struct GetSongOptionsArgs { - options: GetSongOptions, -} - -#[derive(Serialize)] -struct GetEntityOptionsArgs { - options: GetEntityOptions, -} - #[tracing::instrument(level = "trace", skip(options, setter))] #[cfg(not(feature = "mock"))] pub fn get_songs_by_option( options: GetSongOptions, setter: impl SignalSet> + 'static, ) { - - spawn_local(async move { - let args = to_value(&GetSongOptionsArgs { options }).unwrap(); - let res = invoke("get_songs_by_options", args).await; - if res.is_err() { - tracing::error!("Failed to load songs {:?}", res.unwrap_err()); - setter.set(vec![]); - return; - } - let songs: Vec = from_value(res.unwrap()).unwrap(); + let songs = super::invoke::get_songs_by_options(options).await.unwrap(); setter.set(songs); }); } @@ -160,19 +138,15 @@ where + 'static, { spawn_local(async move { - let args = to_value(&GetEntityOptionsArgs { - options: GetEntityOptions { + let songs = serde_wasm_bindgen::from_value( + super::invoke::get_entity_by_options(GetEntityOptions { playlist: Some(QueryablePlaylist::default()), ..Default::default() - }, - }) + }) + .await + .unwrap(), + ) .unwrap(); - let res = invoke("get_entity_by_options", args).await; - if res.is_err() { - tracing::error!("Error getting playlists: {:?}", res); - return; - } - let songs: Vec = from_value(res.unwrap()).unwrap(); setter.set(songs); }); } @@ -193,14 +167,11 @@ where let provider_store = expect_context::>(); spawn_local(async move { - let args = to_value(&GetEntityOptionsArgs { - options: GetEntityOptions { - playlist: Some(options), - ..Default::default() - }, + let res = super::invoke::get_entity_by_options(GetEntityOptions { + playlist: Some(options), + ..Default::default() }) - .unwrap(); - let res = invoke("get_entity_by_options", args).await; + .await; if res.is_err() { tracing::error!("Error getting playlists: {:?}", res); return; @@ -223,17 +194,12 @@ pub fn get_artists_by_option( options: QueryableArtist, setter: impl SignalSet> + 'static, ) { - - spawn_local(async move { - let args = to_value(&GetEntityOptionsArgs { - options: GetEntityOptions { - artist: Some(options), - ..Default::default() - }, + let res = super::invoke::get_entity_by_options(GetEntityOptions { + artist: Some(options), + ..Default::default() }) - .unwrap(); - let res = invoke("get_entity_by_options", args).await; + .await; if res.is_err() { tracing::error!("Error getting artists: {:?}", res); return; @@ -249,17 +215,12 @@ pub fn get_albums_by_option( options: QueryableAlbum, setter: impl SignalSet> + 'static, ) { - - spawn_local(async move { - let args = to_value(&GetEntityOptionsArgs { - options: GetEntityOptions { - album: Some(options), - ..Default::default() - }, + let res = super::invoke::get_entity_by_options(GetEntityOptions { + album: Some(options), + ..Default::default() }) - .unwrap(); - let res = invoke("get_entity_by_options", args).await; + .await; if res.is_err() { tracing::error!("Error getting albums: {:?}", res); return; @@ -275,17 +236,12 @@ pub fn get_genres_by_option( options: QueryableGenre, setter: impl SignalSet> + 'static, ) { - - spawn_local(async move { - let args = to_value(&GetEntityOptionsArgs { - options: GetEntityOptions { - genre: Some(options), - ..Default::default() - }, + let res = super::invoke::get_entity_by_options(GetEntityOptions { + genre: Some(options), + ..Default::default() }) - .unwrap(); - let res = invoke("get_entity_by_options", args).await; + .await; if res.is_err() { tracing::error!("Error getting genres: {:?}", res); return; @@ -297,17 +253,8 @@ pub fn get_genres_by_option( #[tracing::instrument(level = "trace", skip(songs))] pub fn add_songs_to_library(songs: Vec) { - #[derive(Serialize)] - struct AddSongsArgs { - songs: Vec, - } - spawn_local(async move { - let res = invoke( - "insert_songs", - serde_wasm_bindgen::to_value(&AddSongsArgs { songs }).unwrap(), - ) - .await; + let res = super::invoke::insert_songs(songs).await; if res.is_err() { tracing::error!("Error adding songs: {:?}", res); } @@ -316,20 +263,12 @@ pub fn add_songs_to_library(songs: Vec) { #[tracing::instrument(level = "trace", skip(songs))] pub fn remove_songs_from_library(songs: Vec) { - #[derive(Serialize)] - struct RemoveSongsArgs { - songs: Vec, - } spawn_local(async move { - let res = invoke( - "remove_songs", - serde_wasm_bindgen::to_value(&RemoveSongsArgs { - songs: songs - .iter() - .map(|s| s.song._id.clone().unwrap_or_default()) - .collect(), - }) - .unwrap(), + let res = super::invoke::remove_songs( + songs + .iter() + .map(|s| s.song._id.clone().unwrap_or_default()) + .collect(), ) .await; if res.is_err() { @@ -340,17 +279,8 @@ pub fn remove_songs_from_library(songs: Vec) { #[tracing::instrument(level = "trace", skip(id, songs))] pub fn add_to_playlist(id: String, songs: Vec) { - #[derive(Serialize)] - struct AddToPlaylistArgs { - id: String, - songs: Vec, - } spawn_local(async move { - let res = invoke( - "add_to_playlist", - serde_wasm_bindgen::to_value(&AddToPlaylistArgs { id, songs }).unwrap(), - ) - .await; + let res = super::invoke::add_to_playlist(id, songs).await; if res.is_err() { tracing::error!("Error adding to playlist: {:?}", res); } @@ -360,16 +290,7 @@ pub fn add_to_playlist(id: String, songs: Vec) { #[tracing::instrument(level = "trace", skip(playlist))] pub fn create_playlist(playlist: QueryablePlaylist) { spawn_local(async move { - #[derive(Serialize)] - struct CreatePlaylistArgs { - playlist: QueryablePlaylist, - } - - let res = invoke( - "create_playlist", - serde_wasm_bindgen::to_value(&CreatePlaylistArgs { playlist }).unwrap(), - ) - .await; + let res = super::invoke::create_playlist(playlist).await; if let Err(res) = res { tracing::error!("Failed to create playlist: {:?}", res); } @@ -383,19 +304,7 @@ pub fn remove_playlist(playlist: QueryablePlaylist) { } spawn_local(async move { - #[derive(Serialize)] - struct RemovePlaylistArgs { - id: String, - } - - let res = invoke( - "remove_playlist", - serde_wasm_bindgen::to_value(&RemovePlaylistArgs { - id: playlist.playlist_id.unwrap(), - }) - .unwrap(), - ) - .await; + let res = super::invoke::remove_playlist(playlist.playlist_id.unwrap()).await; if let Err(res) = res { tracing::error!("Failed to remove playlist: {:?}", res); } @@ -405,19 +314,7 @@ pub fn remove_playlist(playlist: QueryablePlaylist) { #[tracing::instrument(level = "trace", skip(playlist))] pub fn export_playlist(playlist: QueryablePlaylist) { spawn_local(async move { - #[derive(Serialize)] - struct ExportPlaylistArgs { - id: String, - } - - let res = invoke( - "export_playlist", - serde_wasm_bindgen::to_value(&ExportPlaylistArgs { - id: playlist.playlist_id.unwrap(), - }) - .unwrap(), - ) - .await; + let res = super::invoke::export_playlist(playlist.playlist_id.unwrap()).await; if let Err(res) = res { tracing::error!("Failed to export playlist: {:?}", res); } diff --git a/src/utils/invoke.rs b/src/utils/invoke.rs new file mode 100644 index 00000000..2d423301 --- /dev/null +++ b/src/utils/invoke.rs @@ -0,0 +1 @@ +tauri_invoke_proc::generate_tauri_invoke!(["types"]); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d0b47bfe..491285aa 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -3,6 +3,7 @@ pub mod context_menu; pub mod db_utils; pub mod entities; pub mod extensions; +pub mod invoke; pub mod mpris; pub mod prefs; pub mod providers; diff --git a/src/utils/prefs.rs b/src/utils/prefs.rs index b52d1f8a..4757867a 100644 --- a/src/utils/prefs.rs +++ b/src/utils/prefs.rs @@ -2,54 +2,27 @@ use std::rc::Rc; use leptos::{spawn_local, SignalSet}; use serde::{de::DeserializeOwned, Serialize}; -use serde_wasm_bindgen::{from_value, to_value}; -use types::errors::Result; use types::themes::ThemeDetails; -use types::window::{DialogFilter, FileResponse}; +use types::window::DialogFilter; use wasm_bindgen::JsValue; use crate::utils::common::listen_event; -use super::common::invoke; - -#[derive(Serialize)] -struct KeyArgs { - key: String, -} - -#[derive(Serialize)] -struct SetKeyArgs { - key: String, - value: T, -} - #[tracing::instrument(level = "trace", skip(key, setter))] pub fn load_selective(key: String, setter: impl SignalSet + 'static) where T: DeserializeOwned, { spawn_local(async move { - let res = load_selective_async(key.clone()).await; + let res = super::invoke::load_selective(key.clone()).await; if let Err(e) = res { tracing::error!("Failed to load preference: {}: {:?}", key, e); return; } - setter.set(res.unwrap()); + setter.set(serde_wasm_bindgen::from_value(res.unwrap()).unwrap()); }); } -#[tracing::instrument(level = "trace", skip(key))] -pub async fn load_selective_async(key: String) -> Result -where - T: DeserializeOwned, -{ - let args = to_value(&KeyArgs { key: key.clone() }).unwrap(); - let res = invoke("load_selective", args).await?; - let parsed = serde_wasm_bindgen::from_value(res); - - Ok(parsed?) -} - #[tracing::instrument(level = "trace", skip(key, value))] pub fn save_selective_number(key: String, value: String) { let val = value.parse::().unwrap(); @@ -63,12 +36,10 @@ where T: Serialize + 'static, { spawn_local(async move { - let args = to_value(&SetKeyArgs { - key: key.clone(), - value, - }) - .unwrap(); - let _ = invoke("save_selective", args).await; + let res = super::invoke::save_selective(key.clone(), Some(value)).await; + if let Err(e) = res { + tracing::error!("Error saving selective {}: {:?}", key, e); + } }); } @@ -90,28 +61,14 @@ pub fn open_file_browser( filters: Vec, setter: impl SignalSet> + 'static, ) { - #[derive(Serialize)] - struct FileBrowserArgs { - directory: bool, - multiple: bool, - filters: Vec, - } spawn_local(async move { - let args = to_value(&FileBrowserArgs { - directory, - multiple, - filters, - }) - .unwrap(); - - let res = invoke("open_file_browser", args).await; + let res = super::invoke::open_file_browser(directory, multiple, filters).await; if res.is_err() { tracing::error!("Failed to open file browser"); return; } - let file_resp: Vec = from_value(res.unwrap()).unwrap(); - tracing::debug!("Got file response {:?}", file_resp); - setter.set(file_resp.iter().map(|f| f.path.clone()).collect()); + tracing::debug!("Got file response {:?}", res); + setter.set(res.unwrap().iter().map(|f| f.path.clone()).collect()); }) } @@ -121,27 +78,13 @@ pub fn open_file_browser_single( filters: Vec, setter: impl SignalSet + 'static, ) { - #[derive(Serialize)] - struct FileBrowserArgs { - directory: bool, - multiple: bool, - filters: Vec, - } spawn_local(async move { - let args = to_value(&FileBrowserArgs { - directory, - multiple: false, - filters, - }) - .unwrap(); - - let res = invoke("open_file_browser", args).await; - if res.is_err() { + let file_resp = super::invoke::open_file_browser(directory, false, filters).await; + if file_resp.is_err() { tracing::error!("Failed to open file browser"); return; } - let file_resp: Vec = from_value(res.unwrap()).unwrap(); - setter.set(file_resp.first().unwrap().path.clone()); + setter.set(file_resp.unwrap().first().unwrap().path.clone()); }) } @@ -168,14 +111,8 @@ where T: Fn() + 'static, { let cb = Rc::new(Box::new(cb)); - #[derive(Serialize)] - struct ImportThemeArgs { - path: String, - } spawn_local(async move { - let args = to_value(&ImportThemeArgs { path }).unwrap(); - - let res = invoke("import_theme", args).await; + let res = super::invoke::import_theme(path).await; if res.is_err() { tracing::error!("Failed to import theme"); } @@ -191,14 +128,8 @@ where T: Fn() + 'static, { let cb = Rc::new(Box::new(cb)); - #[derive(Serialize)] - struct SaveThemeArgs { - theme: ThemeDetails, - } spawn_local(async move { - let args = to_value(&SaveThemeArgs { theme }).unwrap(); - - let res = invoke("save_theme", args).await; + let res = super::invoke::save_theme(theme).await; if res.is_err() { tracing::error!("Failed to save theme"); }