Skip to content

Commit

Permalink
wasmtime(gc): Fix wasm-to-native trampoline lookup for subtyping (#8579)
Browse files Browse the repository at this point in the history
* wasmtime(gc): Fix wasm-to-native trampoline lookup for subtyping

Previously, we would look up a wasm-to-native trampoline in the Wasm module
based on the host function's type. With Wasm GC and subtyping, this becomes
problematic because a Wasm module can import a function of type `T` but the host
can define a function of type `U` where `U <: T`. And if the Wasm has never
defined type `U` then it wouldn't have a trampoline for it. But our trampolines
don't actually care, they treat all reference values within the same type
hierarchy identically. So the trampoline for `T` would have worked in
practice. But once we find a trampoline for a function, we cache it and reuse it
every time that function is used in the same store again. Even if the function
is imported with its precise type somewhere else. So then we would have a
trampoline of the wrong type. But this happened to be okay in practice because
the trampolines happen not to inspect their arguments or do anything with them
other than forward them between calling convention locations. But relying on
that accidental invariant seems fragile and like a gun aimed at the future's
feet.

This commit makes that invariant non-accidental, centering it and hopefully
making it less fragile by doing so, by making every function type have an
associated "trampoline type". A trampoline type is the original function type
but where all the reference types in its params and results are replaced with
the nullable top versions, e.g. `(ref $my_struct)` is replaced with `(ref null
any)`. Often a function type is its own associated trampoline type, as is the
case for all functions that don't take or return any references, for
example. Then, all trampoline lookup begins by first getting the trampoline type
of the actual function type, or actual import type, and then only afterwards
looking for the pre-compiled trampoline in the Wasm module.

Fixes #8432

Co-Authored-By: Jamey Sharp <[email protected]>

* Fix no-std build

---------

Co-authored-by: Jamey Sharp <[email protected]>
  • Loading branch information
fitzgen and jameysharp authored May 8, 2024
1 parent 3308a2b commit a947340
Show file tree
Hide file tree
Showing 10 changed files with 487 additions and 47 deletions.
82 changes: 79 additions & 3 deletions crates/environ/src/compile/module_types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::{EntityRef, Module, ModuleTypes, TypeConvert};
use std::{collections::HashMap, ops::Index};
use std::{borrow::Cow, collections::HashMap, ops::Index};
use wasmparser::{UnpackedIndex, Validator, ValidatorId};
use wasmtime_types::{
EngineOrModuleTypeIndex, ModuleInternedRecGroupIndex, ModuleInternedTypeIndex, TypeIndex,
WasmCompositeType, WasmHeapType, WasmResult, WasmSubType,
WasmCompositeType, WasmFuncType, WasmHeapType, WasmResult, WasmSubType,
};

/// A type marking the start of a recursion group's definition.
Expand Down Expand Up @@ -31,6 +31,12 @@ pub struct ModuleTypesBuilder {
/// The canonicalized and deduplicated set of types we are building.
types: ModuleTypes,

/// The set of trampoline-compatible function types we have already added to
/// `self.types`. We do this additional level of deduping, on top of what
/// `wasmparser` already does, so we can quickly and easily get the
/// trampoline type for a given function type if we've already interned one.
trampoline_types: HashMap<WasmFuncType, ModuleInternedTypeIndex>,

/// A map from already-interned `wasmparser` types to their corresponding
/// Wasmtime type.
wasmparser_to_wasmtime: HashMap<wasmparser::types::CoreTypeId, ModuleInternedTypeIndex>,
Expand All @@ -49,6 +55,7 @@ impl ModuleTypesBuilder {
Self {
validator_id: validator.id(),
types: ModuleTypes::default(),
trampoline_types: HashMap::default(),
wasmparser_to_wasmtime: HashMap::default(),
already_seen: HashMap::default(),
defining_rec_group: None,
Expand Down Expand Up @@ -110,7 +117,63 @@ impl ModuleTypesBuilder {
self.wasm_sub_type_in_rec_group(id, wasm_ty);
}

Ok(self.end_rec_group(rec_group_id))
let rec_group_index = self.end_rec_group(rec_group_id);

// Iterate over all the types we just defined and make sure that every
// function type has an associated trampoline type. This needs to happen
// *after* we finish defining the rec group because we may need to
// intern new function types, which would conflict with the contiguous
// range of type indices we pre-reserved for the rec group elements.
for ty in self.rec_group_elements(rec_group_index) {
if self.types[ty].is_func() {
let trampoline = self.intern_trampoline_type(ty);
self.types.set_trampoline_type(ty, trampoline);
}
}

Ok(rec_group_index)
}

/// Look up, or define on demand, the trampoline function type that belongs to
/// the function type `for_func_ty`.
///
/// Returns the interned type index of that trampoline type. When
/// `for_func_ty` is already trampoline-compatible, the returned index is
/// `for_func_ty` itself.
fn intern_trampoline_type(
    &mut self,
    for_func_ty: ModuleInternedTypeIndex,
) -> ModuleInternedTypeIndex {
    let trampoline = self.types[for_func_ty].unwrap_func().trampoline_type();

    // Fast path: this trampoline type was interned earlier, so hand back the
    // index we recorded for it.
    if let Some(&existing) = self.trampoline_types.get(&trampoline) {
        return existing;
    }

    // Slow path: first time we see this trampoline type.
    match trampoline {
        // The function type is its own trampoline type. No new definition is
        // needed; just remember the mapping so later lookups hit the fast
        // path.
        Cow::Borrowed(func_ty) => {
            self.trampoline_types.insert(func_ty.clone(), for_func_ty);
            for_func_ty
        }
        // The trampoline type differs from the function type. Define it as a
        // brand new type, then remember it for later lookups.
        Cow::Owned(func_ty) => {
            let idx = self.types.push(WasmSubType {
                composite_type: WasmCompositeType::Func(func_ty.clone()),
            });

            // A trampoline type always maps to itself.
            self.types.set_trampoline_type(idx, idx);

            // Register the new type as its own single-element rec group.
            let next = self.types.next_ty();
            self.types.push_rec_group(idx..next);

            self.trampoline_types.insert(func_ty, idx);
            idx
        }
    }
}

/// Start defining a recursion group.
Expand Down Expand Up @@ -248,6 +311,19 @@ impl ModuleTypesBuilder {
pub fn wasm_types(&self) -> impl Iterator<Item = (ModuleInternedTypeIndex, &WasmSubType)> {
self.types.wasm_types()
}

/// Get an iterator over all function types and their associated trampoline
/// type.
///
/// Each item is a `(function type, trampoline type)` pair of interned type
/// indices. This forwards directly to the `trampoline_types` method of the
/// `ModuleTypes` being built.
pub fn trampoline_types(
&self,
) -> impl Iterator<Item = (ModuleInternedTypeIndex, ModuleInternedTypeIndex)> + '_ {
self.types.trampoline_types()
}

/// Get the associated trampoline type for the given function type.
///
/// This forwards directly to the `trampoline_type` method of the
/// `ModuleTypes` being built, so `ty` must be a function type that already
/// has a trampoline type associated with it.
pub fn trampoline_type(&self, ty: ModuleInternedTypeIndex) -> ModuleInternedTypeIndex {
self.types.trampoline_type(ty)
}
}

// Forward the indexing impl to the internal `ModuleTypes`
Expand Down
2 changes: 1 addition & 1 deletion crates/environ/src/module_artifacts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct WasmFunctionInfo {

/// Description of where a function is located in the text section of a
/// compiled image.
#[derive(Copy, Clone, Serialize, Deserialize)]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct FunctionLoc {
/// The byte offset from the start of the text section where this
/// function starts.
Expand Down
54 changes: 51 additions & 3 deletions crates/environ/src/module_types.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use crate::PrimaryMap;
use core::ops::{Index, Range};
use cranelift_entity::{packed_option::PackedOption, SecondaryMap};
use serde_derive::{Deserialize, Serialize};
use wasmtime_types::{ModuleInternedRecGroupIndex, ModuleInternedTypeIndex, WasmSubType};

/// All types used in a core wasm module.
///
/// At this time this only contains function types. Note, though, that function
/// types are deduplicated within this [`ModuleTypes`].
///
/// Note that accessing this type is primarily done through the `Index`
/// implementations for this type.
#[derive(Default, Serialize, Deserialize)]
pub struct ModuleTypes {
/// For each rec group, the range of interned type indices of its elements.
rec_groups: PrimaryMap<ModuleInternedRecGroupIndex, Range<ModuleInternedTypeIndex>>,
/// The interned type definitions, keyed by their interned type index.
wasm_types: PrimaryMap<ModuleInternedTypeIndex, WasmSubType>,
/// For each function type, the interned index of its associated trampoline
/// type. An entry stays empty until a trampoline type has been recorded for
/// that type.
trampoline_types: SecondaryMap<ModuleInternedTypeIndex, PackedOption<ModuleInternedTypeIndex>>,
}

impl ModuleTypes {
Expand Down Expand Up @@ -58,6 +57,55 @@ impl ModuleTypes {
self.wasm_types.push(ty)
}

/// Enumerate the function types of this module together with the trampoline
/// type recorded for each of them.
///
/// Every item is a `(function type, trampoline type)` pair; the two indices
/// may be identical. Types that have no trampoline type recorded are
/// skipped.
///
/// See the docs for `WasmFuncType::trampoline_type` for details on
/// trampoline types.
pub fn trampoline_types(
    &self,
) -> impl Iterator<Item = (ModuleInternedTypeIndex, ModuleInternedTypeIndex)> + '_ {
    self.trampoline_types
        .iter()
        // `?` drops entries whose slot is empty.
        .filter_map(|(func_ty, slot)| Some((func_ty, slot.expand()?)))
}

/// Get the trampoline type for the given function type.
///
/// See the docs for `WasmFuncType::trampoline_type` for details on
/// trampoline types.
///
/// # Panics
///
/// In debug builds, asserts that `ty` refers to a function type. The stored
/// entry is unwrapped, so `ty` must already have a trampoline type
/// associated with it.
pub fn trampoline_type(&self, ty: ModuleInternedTypeIndex) -> ModuleInternedTypeIndex {
debug_assert!(self[ty].is_func());
self.trampoline_types[ty].unwrap()
}
}

/// Methods that only exist for `ModuleTypesBuilder`.
#[cfg(feature = "compile")]
impl ModuleTypes {
/// Associate `trampoline_ty` as the trampoline type for `for_ty`.
///
/// This is really only for use by the `ModuleTypesBuilder`.
///
/// In debug builds this checks that neither index is the reserved value,
/// that `for_ty` is a function type with no trampoline type recorded yet,
/// and that `trampoline_ty` is itself a trampoline-compatible function type.
pub fn set_trampoline_type(
&mut self,
for_ty: ModuleInternedTypeIndex,
trampoline_ty: ModuleInternedTypeIndex,
) {
use cranelift_entity::packed_option::ReservedValue;

// Neither index may be the reserved value (presumably because
// `PackedOption` uses it to encode "none" -- confirm against
// `cranelift_entity`).
debug_assert!(!for_ty.is_reserved_value());
debug_assert!(!trampoline_ty.is_reserved_value());
// Only function types have trampoline types.
debug_assert!(self.wasm_types[for_ty].is_func());
// A trampoline type may be recorded at most once per function type.
debug_assert!(self.trampoline_types[for_ty].is_none());
// The target must already be in trampoline-compatible form.
debug_assert!(self.wasm_types[trampoline_ty]
.unwrap_func()
.is_trampoline_type());

self.trampoline_types[for_ty] = Some(trampoline_ty).into();
}

/// Adds a new rec group to this interned list of types.
pub fn push_rec_group(
&mut self,
Expand Down
57 changes: 57 additions & 0 deletions crates/types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use wasmparser;
#[doc(hidden)]
pub use alloc::format as __format;

use alloc::borrow::Cow;
use alloc::boxed::Box;
use core::{fmt, ops::Range};
use cranelift_entity::entity_impl;
Expand Down Expand Up @@ -222,6 +223,20 @@ impl WasmValType {
_ => false,
}
}

/// Map this value type to the form used in trampoline signatures.
///
/// Reference types become the nullable reference to the top type of their
/// heap type; every other value type is returned unchanged.
fn trampoline_type(&self) -> Self {
    match self {
        // Non-reference types are already in trampoline form.
        WasmValType::I32
        | WasmValType::I64
        | WasmValType::F32
        | WasmValType::F64
        | WasmValType::V128 => self.clone(),
        // References widen to the nullable top of their hierarchy.
        WasmValType::Ref(ref_ty) => {
            let top = ref_ty.heap_type.top();
            WasmValType::Ref(WasmRefType {
                nullable: true,
                heap_type: top.into(),
            })
        }
    }
}
}

/// WebAssembly reference type -- equivalent of `wasmparser`'s RefType
Expand Down Expand Up @@ -601,6 +616,48 @@ impl WasmFuncType {
pub fn non_i31_gc_ref_returns_count(&self) -> usize {
self.non_i31_gc_ref_returns_count
}

/// Whether this function type can be used as a trampoline type as-is.
///
/// That is the case when every parameter and every result is already equal
/// to its own trampoline form.
pub fn is_trampoline_type(&self) -> bool {
    self.params()
        .iter()
        .chain(self.returns().iter())
        .all(|ty| *ty == ty.trampoline_type())
}

/// Produce the trampoline-compatible version of this function type.
///
/// A borrowed `Cow` is returned when this type can already serve as a
/// trampoline type unchanged. Otherwise a new, adjusted function type is
/// built and returned as an owned `Cow`.
///
/// ## What is a trampoline type?
///
/// It is the function type with every reference type in its parameters and
/// results replaced by the nullable top type of the same hierarchy, e.g.
/// `(ref $my_struct_type)` becomes `(ref null any)`.
///
/// Two benefits follow. Functions whose signatures map to the same
/// trampoline type can share a single trampoline. And the host can satisfy a
/// Wasm module's import of type `S` with a function of type `T` where
/// `T <: S`, even if the module never defines `T` (and perhaps never could).
///
/// The cost is a restriction on trampolines themselves: they may only move
/// references around (for instance from one calling convention's location
/// to another's) and must not inspect them. Lifting that restriction would
/// require explicit, fallible downcasts in the trampolines, at which point
/// this design should probably be revisited.
pub fn trampoline_type(&self) -> Cow<'_, Self> {
    if !self.is_trampoline_type() {
        let params = self.params().iter().map(|p| p.trampoline_type()).collect();
        let returns = self.returns().iter().map(|r| r.trampoline_type()).collect();
        Cow::Owned(Self::new(params, returns))
    } else {
        Cow::Borrowed(self)
    }
}
}

/// Represents storage types introduced in the GC spec for array and struct fields.
Expand Down
40 changes: 23 additions & 17 deletions crates/wasmtime/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,21 +491,25 @@ impl<'a> CompileInputs<'a> {
}
}

for (interned_index, interned_ty) in types.wasm_types() {
if let Some(wasm_func_ty) = interned_ty.as_func() {
self.push_input(move |compiler| {
let trampoline = compiler.compile_wasm_to_native_trampoline(wasm_func_ty)?;
Ok(CompileOutput {
key: CompileKey::wasm_to_native_trampoline(interned_index),
symbol: format!(
"signatures[{}]::wasm_to_native_trampoline",
interned_index.as_u32()
),
function: CompiledFunction::Function(trampoline),
info: None,
})
});
let mut trampoline_types_seen = HashSet::new();
for (_func_type_index, trampoline_type_index) in types.trampoline_types() {
let is_new = trampoline_types_seen.insert(trampoline_type_index);
if !is_new {
continue;
}
let trampoline_func_ty = types[trampoline_type_index].unwrap_func();
self.push_input(move |compiler| {
let trampoline = compiler.compile_wasm_to_native_trampoline(trampoline_func_ty)?;
Ok(CompileOutput {
key: CompileKey::wasm_to_native_trampoline(trampoline_type_index),
symbol: format!(
"signatures[{}]::wasm_to_native_trampoline",
trampoline_type_index.as_u32()
),
function: CompiledFunction::Function(trampoline),
info: None,
})
});
}
}

Expand Down Expand Up @@ -835,17 +839,19 @@ impl FunctionIndices {
})
.collect();

let unique_and_sorted_sigs = translation
let unique_and_sorted_trampoline_sigs = translation
.module
.types
.iter()
.map(|(_, ty)| *ty)
.filter(|idx| types[*idx].is_func())
.map(|idx| types.trampoline_type(idx))
.collect::<BTreeSet<_>>();
let wasm_to_native_trampolines = unique_and_sorted_sigs
let wasm_to_native_trampolines = unique_and_sorted_trampoline_sigs
.iter()
.map(|idx| {
let key = CompileKey::wasm_to_native_trampoline(*idx);
let trampoline = types.trampoline_type(*idx);
let key = CompileKey::wasm_to_native_trampoline(trampoline);
let compiled = wasm_to_native_trampolines[&key];
(*idx, symbol_ids_and_locs[compiled.unwrap_function()].1)
})
Expand Down
4 changes: 2 additions & 2 deletions crates/wasmtime/src/runtime/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1304,8 +1304,8 @@ impl Func {

let sig = self.type_index(store.store_data());
module.runtime_info().wasm_to_native_trampoline(sig).expect(
"must have a wasm-to-native trampoline for this signature if the Wasm \
module is importing a function of this signature",
"if the wasm is importing a function of a given type, it must have the \
type's trampoline",
)
},
native_call: f.as_ref().native_call,
Expand Down
8 changes: 6 additions & 2 deletions crates/wasmtime/src/runtime/instantiate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,14 @@ impl CompiledModule {
/// `VMFuncRef::wasm_call` for `Func::wrap`-style host funcrefs
/// that don't have access to a compiler when created.
pub fn wasm_to_native_trampoline(&self, signature: ModuleInternedTypeIndex) -> &[u8] {
let idx = self
let idx = match self
.wasm_to_native_trampolines
.binary_search_by_key(&signature, |entry| entry.0)
.expect("should have a Wasm-to-native trampline for all signatures");
{
Ok(idx) => idx,
Err(_) => panic!("missing trampoline for {signature:?}"),
};

let (_, loc) = self.wasm_to_native_trampolines[idx];
&self.text()[loc.start as usize..][..loc.length as usize]
}
Expand Down
Loading

0 comments on commit a947340

Please sign in to comment.