Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generics handling for function calls #1215

Merged
merged 7 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions engine/lib/import_thir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1004,8 +1004,8 @@ end) : EXPR = struct
| Float k ->
TFloat
(match k with F16 -> F16 | F32 -> F32 | F64 -> F64 | F128 -> F128)
| Arrow value ->
let ({ inputs; output; _ } : Thir.ty_fn_sig) = value.value in
| Arrow signature | Closure (_, { untupled_sig = signature; _ }) ->
let ({ inputs; output; _ } : Thir.ty_fn_sig) = signature.value in
let inputs =
if List.is_empty inputs then [ U.unit_typ ]
else List.map ~f:(c_ty span) inputs
Expand Down
3 changes: 3 additions & 0 deletions frontend/exporter/src/traits/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ pub fn required_predicates<'tcx>(
.iter()
.map(|(clause, _span)| *clause),
),
// The tuple struct/variant constructor functions inherit the generics and predicates from
// their parents.
Variant | Ctor(..) => return required_predicates(tcx, tcx.parent(def_id)),
// We consider all predicates on traits to be outputs
Trait => None,
// `predicates_defined_on` ICEs on other def kinds.
Expand Down
170 changes: 89 additions & 81 deletions frontend/exporter/src/types/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
use crate::prelude::*;
use crate::sinto_as_usize;
#[cfg(feature = "rustc")]
use rustc_middle::{mir, ty};
#[cfg(feature = "rustc")]
use tracing::trace;

#[derive_group(Serializers)]
Expand Down Expand Up @@ -410,75 +412,86 @@ pub(crate) fn get_function_from_def_id_and_generics<'tcx, S: BaseState<'tcx> + H
(def_id.sinto(s), generics, trait_refs, source)
}

/// Get a `FunOperand` from an `Operand` used in a function call.
/// Return the [DefId] of the function referenced by an operand, with the
/// parameters substitution.
/// The [Operand] comes from a [TerminatorKind::Call].
#[cfg(feature = "rustc")]
fn get_function_from_operand<'tcx, S: UnderOwnerState<'tcx> + HasMir<'tcx>>(
s: &S,
op: &rustc_middle::mir::Operand<'tcx>,
) -> (FunOperand, Vec<GenericArg>, Vec<ImplExpr>, Option<ImplExpr>) {
// Match on the func operand: it should be a constant as we don't support
// closures for now.
use rustc_middle::mir::Operand;
use rustc_middle::ty::TyKind;
let ty = op.ty(&s.mir().local_decls, s.base().tcx);
trace!("type: {:?}", ty);
// If the type of the value is one of the singleton types that corresponds to each function,
// that's enough information.
if let TyKind::FnDef(def_id, generics) = ty.kind() {
let (fun_id, generics, trait_refs, trait_info) =
get_function_from_def_id_and_generics(s, *def_id, *generics);
return (FunOperand::Id(fun_id), generics, trait_refs, trait_info);
}
match op {
Operand::Constant(_) => {
unimplemented!("{:?}", op);
}
Operand::Move(place) => {
// Function pointer. A fn pointer cannot have bound variables or trait references, so
// we don't need to extract generics, trait refs, etc.
let place = place.sinto(s);
(FunOperand::Move(place), Vec::new(), Vec::new(), None)
}
Operand::Copy(_place) => {
unimplemented!("{:?}", op);
}
}
}

#[cfg(feature = "rustc")]
fn translate_terminator_kind_call<'tcx, S: BaseState<'tcx> + HasMir<'tcx> + HasOwnerId>(
s: &S,
terminator: &rustc_middle::mir::TerminatorKind<'tcx>,
) -> TerminatorKind {
if let rustc_middle::mir::TerminatorKind::Call {
let tcx = s.base().tcx;
let mir::TerminatorKind::Call {
func,
args,
destination,
target,
unwind,
call_source,
fn_span,
..
} = terminator
{
let (fun, generics, trait_refs, trait_info) = get_function_from_operand(s, func);
else {
unreachable!()
};

TerminatorKind::Call {
fun,
let ty = func.ty(&s.mir().local_decls, tcx);
let hax_ty: crate::Ty = ty.sinto(s);
let sig = match hax_ty.kind() {
TyKind::Arrow(sig) => sig,
TyKind::Closure(_, args) => &args.untupled_sig,
_ => unreachable!("Attempting to call non-function type: {ty:?}"),
Nadrieril marked this conversation as resolved.
Show resolved Hide resolved
};
let fun_op = if let ty::TyKind::FnDef(def_id, generics) = ty.kind() {
// The type of the value is one of the singleton types that corresponds to each function,
// which is enough information.
let (def_id, generics, trait_refs, trait_info) =
get_function_from_def_id_and_generics(s, *def_id, *generics);
FunOperand::Static {
def_id,
generics,
args: args.sinto(s),
destination: destination.sinto(s),
target: target.sinto(s),
unwind: unwind.sinto(s),
call_source: call_source.sinto(s),
fn_span: fn_span.sinto(s),
trait_refs,
trait_info,
}
} else {
unreachable!()
use mir::Operand;
match func {
Operand::Constant(_) => {
unimplemented!("{:?}", func);
}
Operand::Move(place) => {
// Function pointer or closure.
let place = place.sinto(s);
FunOperand::DynamicMove(place)
}
Operand::Copy(_place) => {
unimplemented!("{:?}", func);
}
}
};

let late_bound_generics = sig
.bound_vars
.iter()
.map(|var| match var {
BoundVariableKind::Region(r) => r,
BoundVariableKind::Ty(..) => {
unreachable!("Found late-bound type variable")
}
BoundVariableKind::Const => {
unreachable!("Found late-bound const variable")
Nadrieril marked this conversation as resolved.
Show resolved Hide resolved
}
})
.map(|_| {
GenericArg::Lifetime(Region {
kind: RegionKind::ReErased,
})
})
.collect();
TerminatorKind::Call {
fun: fun_op,
late_bound_generics,
args: args.sinto(s),
destination: destination.sinto(s),
target: target.sinto(s),
unwind: unwind.sinto(s),
fn_span: fn_span.sinto(s),
}
}

Expand Down Expand Up @@ -562,13 +575,25 @@ pub enum SwitchTargets {
SwitchInt(IntUintTy, Vec<(ScalarInt, BasicBlock)>, BasicBlock),
}

/// A value of type `fn<...> A -> B` that can be called.
#[derive_group(Serializers)]
#[derive(Clone, Debug, JsonSchema)]
pub enum FunOperand {
/// Call to a top-level function designated by its id
Id(DefId),
/// Use of a closure
Move(Place),
/// Call to a statically-known function.
Static {
def_id: DefId,
/// If `Some`, this is a method call on the given trait reference. Otherwise this is a call
/// to a known function.
trait_info: Option<ImplExpr>,
/// If this is a trait method call, this only includes the method generics; the trait
/// generics are included in the `ImplExpr` in `trait_info`.
generics: Vec<GenericArg>,
/// Trait predicates required by the function generics. Like for `generics`, this only
/// includes the predicates required by the method, if applicable.
trait_refs: Vec<ImplExpr>,
},
/// Use of a closure or a function pointer value. Counts as a move from the given place.
DynamicMove(Place),
}

#[derive_group(Serializers)]
Expand Down Expand Up @@ -607,18 +632,16 @@ pub enum TerminatorKind {
)]
Call {
fun: FunOperand,
/// We truncate the substitution so as to only include the arguments
/// relevant to the method (and not the trait) if it is a trait method
/// call. See [ParamsInfo] for the full details.
generics: Vec<GenericArg>,
/// A `FunOperand` is a value of type `fn<...> A -> B`. The generics in `<...>` are called
/// "late-bound" and are instantiated anew at each call site. This list provides the
/// generics used at this call-site. They are all lifetimes and at the time of writing are
/// all erased lifetimes.
late_bound_generics: Vec<GenericArg>,
args: Vec<Spanned<Operand>>,
destination: Place,
target: Option<BasicBlock>,
unwind: UnwindAction,
call_source: CallSource,
fn_span: Span,
trait_refs: Vec<ImplExpr>,
trait_info: Option<ImplExpr>,
},
TailCall {
func: Operand,
Expand Down Expand Up @@ -934,26 +957,12 @@ pub enum AggregateKind {
Option<UserTypeAnnotationIndex>,
Option<FieldIdx>,
),
#[custom_arm(rustc_middle::mir::AggregateKind::Closure(rust_id, generics) => {
let def_id : DefId = rust_id.sinto(s);
// The generics is meant to be converted to a function signature. Note
// that Rustc does its job: the PolyFnSig binds the captured local
// type, regions, etc. variables, which means we can treat the local
// closure like any top-level function.
#[custom_arm(rustc_middle::mir::AggregateKind::Closure(def_id, generics) => {
let closure = generics.as_closure();
let sig = closure.sig().sinto(s);

// Solve the predicates from the parent (i.e., the item which defines the closure).
let tcx = s.base().tcx;
let parent_generics = closure.parent_args();
let parent_generics_ref = tcx.mk_args(parent_generics);
// TODO: does this handle nested closures?
let parent = tcx.generics_of(rust_id).parent.unwrap();
let trait_refs = solve_item_required_traits(s, parent, parent_generics_ref);

AggregateKind::Closure(def_id, parent_generics.sinto(s), trait_refs, sig)
let args = ClosureArgs::sfrom(s, *def_id, closure);
AggregateKind::Closure(def_id.sinto(s), args)
})]
Closure(DefId, Vec<GenericArg>, Vec<ImplExpr>, PolyFnSig),
Closure(DefId, ClosureArgs),
Coroutine(DefId, Vec<GenericArg>),
CoroutineClosure(DefId, Vec<GenericArg>),
RawPtr(Ty, Mutability),
Expand Down Expand Up @@ -1208,7 +1217,6 @@ sinto_todo!(rustc_middle::mir, UserTypeProjection);
sinto_todo!(rustc_middle::mir, MirSource<'tcx>);
sinto_todo!(rustc_middle::mir, CoroutineInfo<'tcx>);
sinto_todo!(rustc_middle::mir, VarDebugInfo<'tcx>);
sinto_todo!(rustc_middle::mir, CallSource);
sinto_todo!(rustc_middle::mir, UnwindTerminateReason);
sinto_todo!(rustc_middle::mir::coverage, CoverageKind);
sinto_todo!(rustc_middle::mir::interpret, ConstAllocation<'a>);
72 changes: 67 additions & 5 deletions frontend/exporter/src/types/new/full_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ pub enum FullDefKind<Body> {
inline: InlineAttr,
#[value(s.base().tcx.constness(s.owner_id()) == rustc_hir::Constness::Const)]
is_const: bool,
#[value(s.base().tcx.fn_sig(s.owner_id()).instantiate_identity().sinto(s))]
#[value(get_method_sig(s).sinto(s))]
sig: PolyFnSig,
#[value(s.owner_id().as_local().map(|ldid| Body::body(ldid, s)))]
body: Option<Body>,
Expand All @@ -271,10 +271,8 @@ pub enum FullDefKind<Body> {
is_const: bool,
#[value({
let fun_type = s.base().tcx.type_of(s.owner_id()).instantiate_identity();
match fun_type.kind() {
ty::TyKind::Closure(_, args) => args.as_closure().sinto(s),
_ => unreachable!(),
}
let ty::TyKind::Closure(_, args) = fun_type.kind() else { unreachable!() };
ClosureArgs::sfrom(s, s.owner_id(), args.as_closure())
})]
args: ClosureArgs,
},
Expand Down Expand Up @@ -769,6 +767,70 @@ where
}
}

/// The signature of a method impl may be a subtype of the one expected from the trait decl, as in
/// the example below. For correctness, we must be able to map from the method generics declared in
/// the trait to the actual method generics. Because this would require type inference, we instead
/// simply return the declared signature. This will cause issues if it is possible to use such a
/// more-specific implementation with its more-specific type, but we have a few other issues with
/// lifetime-generic function pointers anyway so this is unlikely to cause problems.
///
/// ```ignore
/// trait MyCompare<Other>: Sized {
/// fn compare(self, other: Other) -> bool;
/// }
/// impl<'a> MyCompare<&'a ()> for &'a () {
/// // This implementation is more general because it works for non-`'a` refs. Note that only
/// // late-bound vars may differ in this way.
/// // `<&'a () as MyCompare<&'a ()>>::compare` has type `fn<'b>(&'a (), &'b ()) -> bool`,
/// // but type `fn(&'a (), &'a ()) -> bool` was expected from the trait declaration.
/// fn compare<'b>(self, _other: &'b ()) -> bool {
/// true
/// }
/// }
/// ```
#[cfg(feature = "rustc")]
fn get_method_sig<'tcx, S>(s: &S) -> ty::PolyFnSig<'tcx>
where
S: UnderOwnerState<'tcx>,
{
let tcx = s.base().tcx;
let def_id = s.owner_id();
let real_sig = tcx.fn_sig(def_id).instantiate_identity();
let item = tcx.associated_item(def_id);
if !matches!(item.container, ty::AssocItemContainer::ImplContainer) {
return real_sig;
}
let Some(decl_method_id) = item.trait_item_def_id else {
return real_sig;
};
let declared_sig = tcx.fn_sig(decl_method_id);

// TODO(Nadrieril): Temporary hack: if the signatures have the same number of bound vars, we
// keep the real signature. While the declared signature is more correct, it is also less
// normalized and we can't normalize without erasing regions but regions are crucial in
// function signatures. Hence we cheat here, until charon gains proper normalization
// capabilities.
if declared_sig.skip_binder().bound_vars().len() == real_sig.bound_vars().len() {
return real_sig;
}

let impl_def_id = item.container_id(tcx);
// The trait predicate that is implemented by the surrounding impl block.
let implemented_trait_ref = tcx
.impl_trait_ref(impl_def_id)
.unwrap()
.instantiate_identity();
// Construct arguments for the declared method generics in the context of the implemented
// method generics.
let impl_args = ty::GenericArgs::identity_for_item(tcx, def_id);
let decl_args = impl_args.rebase_onto(tcx, impl_def_id, implemented_trait_ref.args);
let sig = declared_sig.instantiate(tcx, decl_args);
// Avoids accidentally using the same lifetime name twice in the same scope
// (once in impl parameters, second in the method declaration late-bound vars).
let sig = tcx.anonymize_bound_vars(sig);
sig
}

#[cfg(feature = "rustc")]
fn get_ctor_contents<'tcx, S, Body>(s: &S, ctor_of: CtorOf) -> FullDefKind<Body>
where
Expand Down
Loading
Loading