diff --git a/crates/core/src/chainable_method.rs b/crates/core/src/chainable_method.rs new file mode 100644 index 000000000..e1e8f9724 --- /dev/null +++ b/crates/core/src/chainable_method.rs @@ -0,0 +1,201 @@ +use anyhow::{Result, bail}; +use std::collections::HashSet; +use std::fmt; +use wit_parser::{Function, FunctionKind, Resolve, WorldKey}; + +/// Structure used to parse the command line argument `--chainable-method` consistently +/// across guest generators. +#[cfg_attr(feature = "clap", derive(clap::Parser))] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +#[derive(Clone, Default, Debug)] +pub struct ChainableMethodFilterSet { + /// Determines which resource methods should have chaining enabled. + /// Chaining takes a WIT method import returning nothing, and modifies bindgen + /// in a language-dependent way to return `self` in the glue code. This does + /// not affect the ABI in any way. + /// + /// This option can be passed multiple times and additionally accepts + /// comma-separated values for each option passed. Each individual argument + /// passed here can be one of: + /// + /// - `all` - all applicable methods will be chainable + /// - `-all` - no methods will be chainable + /// - `foo:bar/baz#my-resource` - enable chaining for all methods in a resource + /// - `foo:bar/baz#my-resource.some-method` - enable chaining for particular method + /// + /// Options are processed in the order they are passed here, so if a method + /// matches two directives passed the least-specific one should be last. + #[cfg_attr( + feature = "clap", + arg( + long = "chainable-methods", + value_parser = parse_chainable_method, + value_delimiter =',', + value_name = "FILTER", + ), + )] + chainable_methods: Vec, + + #[cfg_attr(feature = "clap", arg(skip))] + #[cfg_attr(feature = "serde", serde(skip))] + used_options: HashSet, +} + +#[cfg(feature = "clap")] +fn parse_chainable_method(s: &str) -> Result { + Ok(ChainableMethod::parse(s)) +} + +impl ChainableMethodFilterSet { + /// Returns a set where all functions should be chainable or not depending on + /// `enable` provided. + pub fn all(enable: bool) -> ChainableMethodFilterSet { + ChainableMethodFilterSet { + chainable_methods: vec![ChainableMethod { + enabled: enable, + filter: ChainableMethodFilter::All, + }], + used_options: HashSet::new(), + } + } + + /// Returns whether the `func` provided should be made chainable + pub fn should_be_chainable( + &mut self, + resolve: &Resolve, + interface: Option<&WorldKey>, + func: &Function, + is_import: bool, + ) -> bool { + if !is_import { + return false; + } + + if func.result.is_some() { + return false; + } + + match func.kind { + FunctionKind::AsyncMethod(resource) | FunctionKind::Method(resource) => { + let interface_name = match interface.map(|key| resolve.name_world_key(key)) { + Some(str) => str + "#", + None => "".into(), + }; + + let resource_name_to_test = format!( + "{}{}", + interface_name, + resolve.types[resource].name.as_ref().unwrap() + ); + + let method_name_to_test = format!("{}{}", interface_name, func.name); + + for (i, opt) in self.chainable_methods.iter().enumerate() { + match &opt.filter { + ChainableMethodFilter::All => { + self.used_options.insert(i); + return opt.enabled; + } + ChainableMethodFilter::Resource(s) => { + if *s == resource_name_to_test { + self.used_options.insert(i); + return opt.enabled; + } + } + ChainableMethodFilter::Method(s) => { + if *s == method_name_to_test { + self.used_options.insert(i); + return opt.enabled; + } + } + }; + } + + return false; + } + _ => { + return false; + } + } + } + + /// Intended to be used in the header comment of generated code to help + /// indicate what options were specified. + pub fn debug_opts(&self) -> impl Iterator + '_ { + self.chainable_methods.iter().map(|opt| opt.to_string()) + } + + /// Tests whether all `--chainable-method` options were used throughout bindings + /// generation, returning an error if any were unused. + pub fn ensure_all_used(&self) -> Result<()> { + for (i, opt) in self.chainable_methods.iter().enumerate() { + if self.used_options.contains(&i) { + continue; + } + if !matches!(opt.filter, ChainableMethodFilter::All) { + bail!("unused chainable option: {opt}"); + } + } + Ok(()) + } + + /// Pushes a new option into this set. + pub fn push(&mut self, directive: &str) { + self.chainable_methods + .push(ChainableMethod::parse(directive)); + } +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +struct ChainableMethod { + enabled: bool, + filter: ChainableMethodFilter, +} + +impl ChainableMethod { + fn parse(s: &str) -> ChainableMethod { + let (s, enabled) = match s.strip_prefix('-') { + Some(s) => (s, false), + None => (s, true), + }; + let filter = match s { + "all" => ChainableMethodFilter::All, + other => { + if other.contains("[method]") { + ChainableMethodFilter::Method(other.to_string()) + } else { + ChainableMethodFilter::Resource(other.to_string()) + } + } + }; + ChainableMethod { enabled, filter } + } +} + +impl fmt::Display for ChainableMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if !self.enabled { + write!(f, "-")?; + } + self.filter.fmt(f) + } +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize))] +enum ChainableMethodFilter { + All, + Resource(String), + Method(String), +} + +impl fmt::Display for ChainableMethodFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ChainableMethodFilter::All => write!(f, "all"), + ChainableMethodFilter::Resource(s) => write!(f, "{s}"), + ChainableMethodFilter::Method(s) => write!(f, "{s}"), + } + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index b85754160..80d76d3f7 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -14,6 +14,8 @@ mod path; pub use path::name_package_module; mod async_; pub use async_::AsyncFilterSet; +mod chainable_method; +pub use chainable_method::ChainableMethodFilterSet; #[derive(Default, Copy, Clone, PartialEq, Eq, Debug)] pub enum Direction { diff --git a/crates/guest-rust/macro/src/lib.rs b/crates/guest-rust/macro/src/lib.rs index e5ae0b0ef..e5fc3f6ac 100644 --- a/crates/guest-rust/macro/src/lib.rs +++ b/crates/guest-rust/macro/src/lib.rs @@ -6,9 +6,9 @@ use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; use syn::parse::{Error, Parse, ParseStream, Result}; use syn::punctuated::Punctuated; use syn::{Token, braced, token}; -use wit_bindgen_core::AsyncFilterSet; use wit_bindgen_core::WorldGenerator; use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId}; +use wit_bindgen_core::{AsyncFilterSet, ChainableMethodFilterSet}; use wit_bindgen_rust::{Opts, Ownership, WithOption}; #[proc_macro] @@ -66,6 +66,7 @@ impl Parse for Config { let mut source = None; let mut features = Vec::new(); let mut async_configured = false; + let mut method_chaining_configured = false; let mut debug = false; if input.peek(token::Brace) { @@ -165,8 +166,15 @@ impl Parse for Config { async_configured = true; opts.async_ = val; } - Opt::EnableMethodChaining(enable) => { - opts.enable_method_chaining = enable.value(); + Opt::ChainableMethods(val, span) => { + if method_chaining_configured { + return Err(Error::new( + span, + "cannot specify second method chaining config", + )); + } + method_chaining_configured = true; + opts.chainable_methods = val; } } } @@ -321,7 +329,7 @@ mod kw { syn::custom_keyword!(disable_custom_section_link_helpers); syn::custom_keyword!(imports); syn::custom_keyword!(debug); - syn::custom_keyword!(enable_method_chaining); + syn::custom_keyword!(chainable_methods); } #[derive(Clone)] @@ -402,7 +410,7 @@ enum Opt { DisableCustomSectionLinkHelpers(syn::LitBool), Async(AsyncFilterSet, Span), Debug(syn::LitBool), - EnableMethodChaining(syn::LitBool), + ChainableMethods(ChainableMethodFilterSet, Span), } impl Parse for Opt { @@ -567,10 +575,17 @@ impl Parse for Opt { input.parse::()?; input.parse::()?; Ok(Opt::Debug(input.parse()?)) - } else if l.peek(kw::enable_method_chaining) { - input.parse::()?; + } else if l.peek(kw::chainable_methods) { + let span = input.parse::()?.span; input.parse::()?; - Ok(Opt::EnableMethodChaining(input.parse()?)) + + let mut set = ChainableMethodFilterSet::default(); + let contents; + syn::bracketed!(contents in input); + for val in contents.parse_terminated(|p| p.parse::(), Token![,])? { + set.push(&val.value()); + } + Ok(Opt::ChainableMethods(set, span)) } else if l.peek(Token![async]) { let span = input.parse::()?.span; input.parse::()?; diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index 38caf7c14..de92fd140 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -179,8 +179,13 @@ impl<'i> InterfaceGenerator<'i> { private: true, ..Default::default() }; - sig.update_for_func(&func); - self.print_signature(func, true, &sig); + + let should_return_self = + self.r#gen + .should_return_self(self.resolve, interface.map(|p| p.1), func, false); + + sig.update_for_func(&func, should_return_self); + self.print_signature(func, true, &sig, should_return_self); self.src.push_str(";\n"); let trait_method = mem::replace(&mut self.src, prev); methods.push(trait_method); @@ -741,22 +746,37 @@ pub mod vtable{ordinal} {{ async_, ..Default::default() }; + + let should_return_self = self + .r#gen + .should_return_self(self.resolve, interface, func, true); + if let Some(id) = func.kind.resource() { let name = self.resolve.types[id].name.as_ref().unwrap(); let name = to_upper_camel_case(name); uwriteln!(self.src, "impl {name} {{"); sig.use_item_name = true; - sig.update_for_func(&func); + sig.update_for_func(&func, should_return_self); } self.src.push_str("#[allow(unused_unsafe, clippy::all)]\n"); - let params = self.print_signature(func, async_, &sig); + let params = self.print_signature(func, async_, &sig, should_return_self); self.src.push_str("{\n"); self.src.push_str("unsafe {\n"); if async_ { - self.generate_guest_import_body_async(&self.wasm_import_module, func, params); + self.generate_guest_import_body_async( + &self.wasm_import_module, + func, + params, + should_return_self, + ); } else { - self.generate_guest_import_body_sync(&self.wasm_import_module, func, params); + self.generate_guest_import_body_sync( + &self.wasm_import_module, + func, + params, + should_return_self, + ); } self.src.push_str("}\n"); @@ -808,14 +828,9 @@ pub mod vtable{ordinal} {{ module: &str, func: &Function, params: Vec, + should_return_self: bool, ) { - let mut f = FunctionBindgen::new( - self, - params, - module, - false, - self.r#gen.should_return_self(func), - ); + let mut f = FunctionBindgen::new(self, params, module, false, should_return_self); abi::call( f.r#gen.resolve, AbiVariant::GuestImport, @@ -856,6 +871,7 @@ pub mod vtable{ordinal} {{ module: &str, func: &Function, mut params: Vec, + should_return_self: bool, ) { let param_tys = func .params @@ -1086,11 +1102,7 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) self.src, "_MySubtask {{ _unused: core::marker::PhantomData }}.call(({})).await{}", params.join(" "), - if self.r#gen.should_return_self(func) { - ";\nself" - } else { - "" - } + if should_return_self { ";\nself" } else { "" } ); } @@ -1387,9 +1399,14 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) private: true, ..Default::default() }; - sig.update_for_func(&func); + + let should_return_self = + self.r#gen + .should_return_self(self.resolve, interface.map(|p| p.1), func, false); + + sig.update_for_func(&func, should_return_self); self.src.push_str("#[allow(unused_variables)]\n"); - self.print_signature(func, true, &sig); + self.print_signature(func, true, &sig, should_return_self); self.src.push_str("{ unreachable!() }\n"); } @@ -1447,8 +1464,14 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) // } } - fn print_signature(&mut self, func: &Function, params_owned: bool, sig: &FnSig) -> Vec { - let params = self.print_docs_and_params(func, params_owned, sig); + fn print_signature( + &mut self, + func: &Function, + params_owned: bool, + sig: &FnSig, + should_return_self: bool, + ) -> Vec { + let params = self.print_docs_and_params(func, params_owned, sig, should_return_self); self.push_str(" -> "); if let FunctionKind::Constructor(resource_id) = &func.kind { match classify_constructor_return_type(&self.resolve, *resource_id, &func.result) { @@ -1462,8 +1485,8 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) } } } else { - if self.r#gen.should_return_self(func) { - self.push_str("&Self"); + if should_return_self { + self.push_str("Self"); } else { self.print_result_type(&func.result); } @@ -1476,6 +1499,7 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) func: &Function, params_owned: bool, sig: &FnSig, + should_return_self: bool, ) -> Vec { self.rustdoc(&func.docs); self.rustdoc_params(&func.params, "Parameters"); @@ -1523,7 +1547,11 @@ unsafe fn call_import(&mut self, _params: Self::ParamsLower, _results: *mut u8) ) in func.params.iter().enumerate() { if i == 0 && sig.self_is_first_param { - params.push("self".to_string()); + params.push(if should_return_self { + "&self".to_string() + } else { + "self".to_string() + }); continue; } let name = to_rust_ident(name); diff --git a/crates/rust/src/lib.rs b/crates/rust/src/lib.rs index 7241f2d09..fc97a6ff9 100644 --- a/crates/rust/src/lib.rs +++ b/crates/rust/src/lib.rs @@ -10,8 +10,8 @@ use std::path::{Path, PathBuf}; use std::str::FromStr; use wit_bindgen_core::abi::{Bitcast, WasmType}; use wit_bindgen_core::{ - AsyncFilterSet, Files, InterfaceGenerator as _, Source, Types, WorldGenerator, dealias, - name_package_module, uwrite, uwriteln, wit_parser::*, + AsyncFilterSet, ChainableMethodFilterSet, Files, InterfaceGenerator as _, Source, Types, + WorldGenerator, dealias, name_package_module, uwrite, uwriteln, wit_parser::*, }; mod bindgen; @@ -301,9 +301,9 @@ pub struct Opts { )] pub merge_structurally_equal_types: Option>, - /// If true, methods normally returning `()` instead return `&Self`. This applies to both imported and exported methods. - #[cfg_attr(feature = "clap", arg(long))] - pub enable_method_chaining: bool, + #[cfg_attr(feature = "clap", clap(flatten))] + #[cfg_attr(feature = "serde", serde(flatten))] + pub chainable_methods: ChainableMethodFilterSet, } impl Opts { @@ -1061,10 +1061,17 @@ macro_rules! __export_{world_name}_impl {{ .is_async(resolve, interface, func, is_import) } - fn should_return_self(&self, func: &Function) -> bool { - self.opts.enable_method_chaining - && func.result.is_none() - && matches!(&func.kind, FunctionKind::Method(_)) + fn should_return_self( + &mut self, + resolve: &Resolve, + interface: Option<&WorldKey>, + func: &Function, + is_import: bool, + ) -> bool { + return self + .opts + .chainable_methods + .should_be_chainable(resolve, interface, func, is_import); } } @@ -1533,9 +1540,10 @@ impl WorldGenerator for RustWasm { bail!("unused remappings provided via `with`: {unused_keys:?}"); } - // Error about unused async configuration to help catch configuration + // Error about unused async and method chaining configuration to help catch configuration // errors. self.opts.async_.ensure_all_used()?; + self.opts.chainable_methods.ensure_all_used()?; Ok(()) } @@ -1685,9 +1693,13 @@ struct FnSig { } impl FnSig { - fn update_for_func(&mut self, func: &Function) { + fn update_for_func(&mut self, func: &Function, return_self: bool) { if let FunctionKind::Method(_) | FunctionKind::AsyncMethod(_) = &func.kind { - self.self_arg = Some("&self".into()); + self.self_arg = Some(if return_self { + "self".into() + } else { + "&self".into() + }); self.self_is_first_param = true; } } diff --git a/crates/rust/tests/codegen.rs b/crates/rust/tests/codegen.rs index a6df22cf9..cf5ecadbc 100644 --- a/crates/rust/tests/codegen.rs +++ b/crates/rust/tests/codegen.rs @@ -232,6 +232,6 @@ mod method_chaining { } "#, generate_all, - enable_method_chaining: true + chainable_methods: ["all"] }); } diff --git a/tests/runtime/rust/method-chaining/runner.rs b/tests/runtime/rust/method-chaining/runner.rs index cbfaf631e..3a8a89130 100644 --- a/tests/runtime/rust/method-chaining/runner.rs +++ b/tests/runtime/rust/method-chaining/runner.rs @@ -1,8 +1,9 @@ -//@ args = '--enable-method-chaining' +//@ args = '--chainable-methods foo:bar/i#a' include!(env!("BINDINGS")); use crate::foo::bar::i::A; +use crate::foo::bar::i::B; struct Component; export!(Component); @@ -11,5 +12,10 @@ impl Guest for Component { fn run() { let my_a = A::new(); my_a.set_a(42).set_b(true).do_(); + + let my_b = B::new(); + my_b.set_a(42); + my_b.set_b(true); + my_b.do_(); } } diff --git a/tests/runtime/rust/method-chaining/test.rs b/tests/runtime/rust/method-chaining/test.rs index 9e48a05c9..4010cc73f 100644 --- a/tests/runtime/rust/method-chaining/test.rs +++ b/tests/runtime/rust/method-chaining/test.rs @@ -1,14 +1,17 @@ -//@ args = '--enable-method-chaining' +//@ args = '--chainable-methods all' + +// Should have no effect on exports include!(env!("BINDINGS")); -use crate::exports::foo::bar::i::{Guest, GuestA}; +use crate::exports::foo::bar::i::{Guest, GuestA, GuestB}; use std::cell::Cell; struct Component; export!(Component); impl Guest for Component { type A = MyA; + type B = MyB; } struct MyA { @@ -16,6 +19,11 @@ struct MyA { prop_b: Cell, } +struct MyB { + prop_a: Cell, + prop_b: Cell, +} + impl GuestA for MyA { fn new() -> MyA { MyA { @@ -24,17 +32,32 @@ impl GuestA for MyA { } } - fn set_a(&self, a: u32) -> &Self { + fn set_a(&self, a: u32) { self.prop_a.set(a); - self } - fn set_b(&self, b: bool) -> &Self { + fn set_b(&self, b: bool) { self.prop_b.set(b); - self } - fn do_(&self) -> &Self { - self + fn do_(&self) {} +} + +impl GuestB for MyB { + fn new() -> MyB { + MyB { + prop_a: Cell::new(0), + prop_b: Cell::new(false), + } + } + + fn set_a(&self, a: u32) { + self.prop_a.set(a); } + + fn set_b(&self, b: bool) { + self.prop_b.set(b); + } + + fn do_(&self) {} } diff --git a/tests/runtime/rust/method-chaining/test.wit b/tests/runtime/rust/method-chaining/test.wit index f3c18c846..4b7f1dc9d 100644 --- a/tests/runtime/rust/method-chaining/test.wit +++ b/tests/runtime/rust/method-chaining/test.wit @@ -7,6 +7,12 @@ interface i { set-b: func(arg: bool); do: func(); } + resource b { + constructor(); + set-a: func(arg: u32); + set-b: func(arg: bool); + do: func(); + } } world runner { import i;