diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index ebadda540..82874c659 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -21,11 +21,24 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterType}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; +use hyperlight_guest::bail; use hyperlight_guest::error::{HyperlightGuestError, Result}; use tracing::instrument; use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; +core::arch::global_asm!( + ".weak guest_dispatch_function", + ".set guest_dispatch_function, {}", + sym guest_dispatch_function_default, +); + +#[tracing::instrument(skip_all, parent = tracing::Span::current(), level= "Trace")] +fn guest_dispatch_function_default(function_call: FunctionCall) -> Result> { + let name = &function_call.function_name; + bail!(ErrorCode::GuestFunctionNotFound => "No handler found for function call: {name:#?}"); +} + #[instrument(skip_all, level = "Info")] pub(crate) fn call_guest_function(function_call: FunctionCall) -> Result> { // Validate this is a Guest Function Call diff --git a/src/hyperlight_guest_bin/src/lib.rs b/src/hyperlight_guest_bin/src/lib.rs index ccd58a0f8..e4605ee8f 100644 --- a/src/hyperlight_guest_bin/src/lib.rs +++ b/src/hyperlight_guest_bin/src/lib.rs @@ -197,6 +197,17 @@ unsafe extern "C" { fn srand(seed: u32); } +#[tracing::instrument(skip_all, parent = tracing::Span::current(), level= "Trace")] +extern "C" fn hyperlight_main_default() { + // no-op +} + +core::arch::global_asm!( + ".weak hyperlight_main", + ".set hyperlight_main, {}", + sym hyperlight_main_default, +); + /// Architecture-nonspecific initialisation: set up the heap, /// coordinate some addresses and configuration with the host, and run /// user initialisation @@ -286,6 +297,9 @@ pub(crate) extern "C" fn generic_init( #[cfg(feature = "macros")] #[doc(hidden)] pub mod __private { + pub use alloc::vec::Vec; + + pub use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; pub use hyperlight_common::func::ResultType; pub use hyperlight_guest::error::HyperlightGuestError; pub use linkme; @@ -299,7 +313,6 @@ pub mod __private { } use alloc::string::String; - use alloc::vec::Vec; use hyperlight_common::for_each_return_type; @@ -327,6 +340,6 @@ pub mod __private { } #[cfg(feature = "macros")] -pub use hyperlight_guest_macro::{guest_function, host_function}; +pub use hyperlight_guest_macro::{dispatch, guest_function, host_function, main}; pub use crate::guest_function::definition::GuestFunc; diff --git a/src/hyperlight_guest_macro/src/lib.rs b/src/hyperlight_guest_macro/src/lib.rs index 08e1247ce..e4f9404ab 100644 --- a/src/hyperlight_guest_macro/src/lib.rs +++ b/src/hyperlight_guest_macro/src/lib.rs @@ -156,6 +156,103 @@ pub fn guest_function(attr: TokenStream, item: TokenStream) -> TokenStream { output.into() } +/// Attribute macro to mark a function as the main entry point for the guest. +/// This will generate a function that is called by the host at program initialization. +/// +/// # Example +/// ```ignore +/// use hyperlight_guest_bin::main; +/// #[main] +/// fn main() { +/// // do some initialization work here, e.g., initialize global state, etc. +/// } +/// ``` +#[proc_macro_attribute] +pub fn main(_attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse the function definition that we will be working with, and + // early return if parsing as `ItemFn` fails. + let fn_declaration = parse_macro_input!(item as ItemFn); + + // Obtain the name of the function being decorated. + let ident = fn_declaration.sig.ident.clone(); + + // The generated code will replace the decorated code, so we need to + // include the original function declaration in the output. + let output = quote! { + #fn_declaration + + const _: () = { + mod wrapper { + #[unsafe(no_mangle)] + pub extern "C" fn hyperlight_main() { + super::#ident() + } + } + }; + }; + + output.into() +} + +/// Attribute macro to mark a function as the dispatch function for the guest. +/// This is the function that will be called by the host when a function call is made +/// to a function that is not registered with the host. +/// +/// # Example +/// ```ignore +/// use hyperlight_guest_bin::dispatch; +/// use hyperlight_guest::error::Result; +/// use hyperlight_guest::bail; +/// use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; +/// use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; +/// #[dispatch] +/// fn dispatch(fc: FunctionCall) -> Result> { +/// let name = &fc.function_name; +/// if name == "greet" { +/// return Ok(get_flatbuffer_result("Hello, world!")); +/// } +/// bail!("Unknown function: {name}"); +/// } +/// ``` +#[proc_macro_attribute] +pub fn dispatch(_attr: TokenStream, item: TokenStream) -> TokenStream { + // Obtain the crate name for hyperlight-guest-bin + let crate_name = + crate_name("hyperlight-guest-bin").expect("hyperlight-guest-bin must be a dependency"); + let crate_name = match crate_name { + FoundCrate::Itself => quote! {crate}, + FoundCrate::Name(name) => { + let ident = syn::Ident::new(&name, proc_macro2::Span::call_site()); + quote! {::#ident} + } + }; + + // Parse the function definition that we will be working with, and + // early return if parsing as `ItemFn` fails. + let fn_declaration = parse_macro_input!(item as ItemFn); + + // Obtain the name of the function being decorated. + let ident = fn_declaration.sig.ident.clone(); + + // The generated code will replace the decorated code, so we need to + // include the original function declaration in the output. + let output = quote! { + #fn_declaration + + const _: () = { + mod wrapper { + use #crate_name::__private::{FunctionCall, HyperlightGuestError, Vec}; + #[unsafe(no_mangle)] + pub fn guest_dispatch_function(function_call: FunctionCall) -> ::core::result::Result, HyperlightGuestError> { + super::#ident(function_call) + } + } + }; + }; + + output.into() +} + /// Attribute macro to mark a function as a host function. /// This will generate a function that calls the host function with the same name. /// diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index ac876e65f..b6844a716 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -855,9 +855,9 @@ fn call_host_expect_error(hostfuncname: String) -> Result<()> { Ok(()) } -#[no_mangle] +#[hyperlight_guest_bin::main] #[instrument(skip_all, parent = Span::current(), level= "Trace")] -pub extern "C" fn hyperlight_main() { +fn main() { let print_output_def = GuestFunctionDefinition::::new( "PrintOutputWithHostPrint".to_string(), Vec::from(&[ParameterType::String]), @@ -1067,9 +1067,9 @@ fn fuzz_host_function(func: FunctionCall) -> Result> { } } -#[no_mangle] +#[hyperlight_guest_bin::dispatch] #[instrument(skip_all, parent = Span::current(), level= "Trace")] -pub fn guest_dispatch_function(function_call: FunctionCall) -> Result> { +fn dispatch(function_call: FunctionCall) -> Result> { // This test checks the stack behavior of the input/output buffer // by calling the host before serializing the function call. // If the stack is not working correctly, the input or output buffer will be