From c3f7fcdfd6a9fc46a005704d8855eb840f16ef69 Mon Sep 17 00:00:00 2001 From: Lucas Franceschino Date: Tue, 30 Jan 2024 15:26:28 +0100 Subject: [PATCH 1/4] fix(frontend): fix select_trait_candidate --- frontend/exporter/src/traits.rs | 155 ++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 59 deletions(-) diff --git a/frontend/exporter/src/traits.rs b/frontend/exporter/src/traits.rs index 742a1bb51..2ba5d0141 100644 --- a/frontend/exporter/src/traits.rs +++ b/frontend/exporter/src/traits.rs @@ -236,11 +236,7 @@ impl<'tcx> IntoImplExpr<'tcx> for rustc_middle::ty::PolyTraitRef<'tcx> { param_env: rustc_middle::ty::ParamEnv<'tcx>, ) -> ImplExpr { use rustc_trait_selection::traits::*; - let Some(impl_source) = select_trait_candidate(s, param_env, *self) else { - report!(warn, s, "Warning: the frontend could not resolve using `select_trait_candidate`.\nThis is a bug documented in `https://github.com/hacspec/hax/issues/416`.\nCurrently, this bug is non-fatal at the exporter level."); - return ImplExprAtom::Todo(format!("impl_expr failed on {:#?}", self)).into(); - }; - match impl_source { + match select_trait_candidate(s, param_env, *self) { ImplSource::UserDefined(ImplSourceUserDefinedData { impl_def_id, substs, @@ -310,70 +306,111 @@ pub fn clause_id_of_predicate(predicate: rustc_middle::ty::Predicate) -> u64 { predicate.hash(&mut s); s.finish() } - -/// Adapted from [rustc_trait_selection::traits::SelectionContext::select]: -/// we want to preserve the nested obligations to resolve them afterwards. -/// -/// Example: -/// ======== -/// ```text -/// struct Wrapper { -/// x: T, -/// } -/// -/// impl ToU64 for Wrapper { -/// fn to_u64(self) -> u64 { -/// self.x.to_u64() -/// } -/// } -/// -/// fn h(x: Wrapper) -> u64 { -/// x.to_u64() -/// } -/// ``` -/// -/// When resolving the trait for `x.to_u64()` in `h`, we get that it uses the -/// implementation for `Wrapper`. But we also need to know the obligation generated -/// for `Wrapper` (in this case: `u64 : ToU64`) and resolve it. -/// -/// TODO: returns an Option for now, `None` means we hit the indexing -/// bug (see ). #[tracing::instrument(level = "trace", skip(s))] pub fn select_trait_candidate<'tcx, S: UnderOwnerState<'tcx>>( s: &S, param_env: rustc_middle::ty::ParamEnv<'tcx>, trait_ref: rustc_middle::ty::PolyTraitRef<'tcx>, -) -> Option> { - use rustc_infer::infer::TyCtxtInferExt; - use rustc_trait_selection::traits::{Obligation, ObligationCause, SelectionContext}; +) -> rustc_trait_selection::traits::Selection<'tcx> { let tcx = s.base().tcx; - let trait_ref = tcx - .try_normalize_erasing_regions(param_env, trait_ref) - .unwrap_or(trait_ref); - - // Do the initial selection for the obligation. This yields the - // shallow result we are looking for -- that is, what specific impl. - let infcx = tcx.infer_ctxt().ignoring_regions().build(); - let mut selcx = SelectionContext::new(&infcx); - - let obligation_cause = ObligationCause::dummy(); - let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref); - - let selection = { - use std::panic; - panic::set_hook(Box::new(|_info| {})); - let result = panic::catch_unwind(panic::AssertUnwindSafe(|| selcx.select(&obligation))); - let _ = panic::take_hook(); - result - }; - match selection { - Ok(Ok(Some(selection))) => Some(infcx.resolve_vars_if_possible(selection)), - Ok(error) => fatal!( + match copy_paste_from_rustc::codegen_select_candidate(tcx, (param_env, trait_ref)) { + Ok(selection) => selection, + Err(error) => fatal!( s, "Cannot hanlde error `{:?}` selecting `{:?}`", error, trait_ref ), - Err(_) => None, + } +} + +pub mod copy_paste_from_rustc { + use rustc_infer::infer::TyCtxtInferExt; + use rustc_infer::traits::{FulfillmentErrorCode, TraitEngineExt as _}; + use rustc_middle::traits::{CodegenObligationError, DefiningAnchor}; + use rustc_middle::ty::{self, TyCtxt}; + use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt; + use rustc_trait_selection::traits::{ + ImplSource, Obligation, ObligationCause, SelectionContext, TraitEngine, TraitEngineExt, + Unimplemented, + }; + + /// Attempts to resolve an obligation to an `ImplSource`. The result is + /// a shallow `ImplSource` resolution, meaning that we do not + /// (necessarily) resolve all nested obligations on the impl. Note + /// that type check should guarantee to us that all nested + /// obligations *could be* resolved if we wanted to. + /// + /// This also expects that `trait_ref` is fully normalized. + pub fn codegen_select_candidate<'tcx>( + tcx: TyCtxt<'tcx>, + (param_env, trait_ref): (ty::ParamEnv<'tcx>, ty::PolyTraitRef<'tcx>), + ) -> Result, CodegenObligationError> { + // We expect the input to be fully normalized. + debug_assert_eq!( + trait_ref, + tcx.normalize_erasing_regions(param_env, trait_ref) + ); + + // Do the initial selection for the obligation. This yields the + // shallow result we are looking for -- that is, what specific impl. + let infcx = tcx + .infer_ctxt() + .ignoring_regions() + .with_opaque_type_inference(DefiningAnchor::Bubble) + .build(); + //~^ HACK `Bubble` is required for + // this test to pass: type-alias-impl-trait/assoc-projection-ice.rs + let mut selcx = SelectionContext::new(&infcx); + + let obligation_cause = ObligationCause::dummy(); + let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref); + + let selection = match selcx.select(&obligation) { + Ok(Some(selection)) => selection, + Ok(None) => return Err(CodegenObligationError::Ambiguity), + Err(Unimplemented) => return Err(CodegenObligationError::Unimplemented), + Err(e) => { + panic!( + "Encountered error `{:?}` selecting `{:?}` during codegen", + e, trait_ref + ) + } + }; + + // Currently, we use a fulfillment context to completely resolve + // all nested obligations. This is because they can inform the + // inference of the impl's type parameters. + let mut fulfill_cx = >::new(tcx); + let impl_source = selection.map(|predicate| { + fulfill_cx.register_predicate_obligation(&infcx, predicate.clone()); + predicate + }); + + // In principle, we only need to do this so long as `impl_source` + // contains unbound type parameters. It could be a slight + // optimization to stop iterating early. + let errors = fulfill_cx.select_all_or_error(&infcx); + if !errors.is_empty() { + // `rustc_monomorphize::collector` assumes there are no type errors. + // Cycle errors are the only post-monomorphization errors possible; emit them now so + // `rustc_ty_utils::resolve_associated_item` doesn't return `None` post-monomorphization. + for err in errors { + if let FulfillmentErrorCode::CodeCycle(cycle) = err.code { + infcx.err_ctxt().report_overflow_obligation_cycle(&cycle); + } + } + return Err(CodegenObligationError::FulfillmentError); + } + + let impl_source = infcx.resolve_vars_if_possible(impl_source); + let impl_source = infcx.tcx.erase_regions(impl_source); + + // Opaque types may have gotten their hidden types constrained, but we can ignore them safely + // as they will get constrained elsewhere, too. + // (ouz-a) This is required for `type-alias-impl-trait/assoc-projection-ice.rs` to pass + let _ = infcx.take_opaque_types(); + + Ok(impl_source) } } From 194e255afacaae9d3fe98b68588cc5c4c25ffae3 Mon Sep 17 00:00:00 2001 From: Lucas Franceschino Date: Wed, 31 Jan 2024 16:30:11 +0100 Subject: [PATCH 2/4] feat(frontend/traits): misc improvements with Son Co-authored-by: Son Ho --- frontend/exporter/src/traits.rs | 132 +++++++++++++++++++++++++------- 1 file changed, 105 insertions(+), 27 deletions(-) diff --git a/frontend/exporter/src/traits.rs b/frontend/exporter/src/traits.rs index 2ba5d0141..42ed1d9d4 100644 --- a/frontend/exporter/src/traits.rs +++ b/frontend/exporter/src/traits.rs @@ -6,8 +6,15 @@ use crate::prelude::*; Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, )] pub enum ImplExprPathChunk { - AssocItem(AssocItem, TraitPredicate), - Parent(TraitPredicate), + AssocItem { + item: AssocItem, + predicate: TraitPredicate, + index: usize, + }, + Parent { + predicate: TraitPredicate, + index: usize, + }, } #[derive( @@ -20,6 +27,7 @@ pub enum ImplExprAtom { }, LocalBound { clause_id: u64, + r#trait: Binder, path: Vec, }, SelfImpl, @@ -33,6 +41,14 @@ pub enum ImplExprAtom { Builtin { r#trait: TraitRef, }, + FnPointer { + fn_ty: Box, + }, + Closure { + closure_def_id: DefId, + parent_substs: Vec, + signature: Box, + }, Todo(String), } @@ -40,8 +56,9 @@ pub enum ImplExprAtom { Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, )] pub struct ImplExpr { - r#impl: ImplExprAtom, - args: Box>, + pub r#impl: ImplExprAtom, + pub args: Box>, + pub r#trait: TraitRef, } mod search_clause { @@ -65,8 +82,15 @@ mod search_clause { #[derive(Clone, Debug)] pub enum PathChunk<'tcx> { - AssocItem(AssocItem, TraitPredicate<'tcx>), - Parent(TraitPredicate<'tcx>), + AssocItem { + item: AssocItem, + predicate: TraitPredicate<'tcx>, + index: usize, + }, + Parent { + predicate: TraitPredicate<'tcx>, + index: usize, + }, } pub type Path<'tcx> = Vec>; @@ -116,18 +140,23 @@ mod search_clause { #[extension_traits::extension(pub trait TraitPredicateExt)] impl<'tcx, S: UnderOwnerState<'tcx>> TraitPredicate<'tcx> { - fn parents_trait_predicates(self, s: &S) -> Vec> { + fn parents_trait_predicates(self, s: &S) -> Vec<(usize, TraitPredicate<'tcx>)> { let tcx = s.base().tcx; let predicates = tcx .predicates_defined_on_or_above(self.def_id()) .into_iter() .map(|apred| apred.predicate); - predicates_to_trait_predicates(tcx, predicates, self.trait_ref.substs).collect() + predicates_to_trait_predicates(tcx, predicates, self.trait_ref.substs) + .enumerate() + .collect() } fn associated_items_trait_predicates( self, s: &S, - ) -> Vec<(AssocItem, subst::EarlyBinder>>)> { + ) -> Vec<( + AssocItem, + subst::EarlyBinder)>>, + )> { let tcx = s.base().tcx; tcx.associated_items(self.def_id()) .in_definition_order() @@ -140,6 +169,7 @@ mod search_clause { predicates.into_iter(), self.trait_ref.substs, ) + .enumerate() .collect() }); (item, bounds) @@ -174,14 +204,33 @@ mod search_clause { } self.parents_trait_predicates(s) .into_iter() - .filter_map(|p| recurse(p).map(|path| cons(PathChunk::Parent(p), path))) + .filter_map(|(index, p)| { + recurse(p).map(|path| { + cons( + PathChunk::Parent { + predicate: p, + index, + }, + path, + ) + }) + }) .max_by_key(|path| path.len()) .or_else(|| { self.associated_items_trait_predicates(s) .into_iter() .filter_map(|(item, binder)| { - binder.skip_binder().into_iter().find_map(|p| { - recurse(p).map(|path| cons(PathChunk::AssocItem(item, p), path)) + binder.skip_binder().into_iter().find_map(|(index, p)| { + recurse(p).map(|path| { + cons( + PathChunk::AssocItem { + item, + predicate: p, + index, + }, + path, + ) + }) }) }) .max_by_key(|path| path.len()) @@ -191,18 +240,14 @@ mod search_clause { } impl ImplExprAtom { - fn with_args(self, args: Vec) -> ImplExpr { + fn with_args(self, args: Vec, r#trait: TraitRef) -> ImplExpr { ImplExpr { r#impl: self, args: Box::new(args), + r#trait, } } } -impl From for ImplExpr { - fn from(implem: ImplExprAtom) -> ImplExpr { - implem.with_args(vec![]) - } -} fn impl_exprs<'tcx, S: UnderOwnerState<'tcx>>( s: &S, @@ -236,6 +281,8 @@ impl<'tcx> IntoImplExpr<'tcx> for rustc_middle::ty::PolyTraitRef<'tcx> { param_env: rustc_middle::ty::ParamEnv<'tcx>, ) -> ImplExpr { use rustc_trait_selection::traits::*; + let trait_ref: Binder = self.sinto(s); + let trait_ref = trait_ref.value; match select_trait_candidate(s, param_env, *self) { ImplSource::UserDefined(ImplSourceUserDefinedData { impl_def_id, @@ -245,7 +292,7 @@ impl<'tcx> IntoImplExpr<'tcx> for rustc_middle::ty::PolyTraitRef<'tcx> { id: impl_def_id.sinto(s), generics: substs.sinto(s), } - .with_args(impl_exprs(s, &nested)), + .with_args(impl_exprs(s, &nested), trait_ref), ImplSource::Param(nested, _constness) => { use search_clause::TraitPredicateExt; let tcx = s.base().tcx; @@ -262,38 +309,70 @@ impl<'tcx> IntoImplExpr<'tcx> for rustc_middle::ty::PolyTraitRef<'tcx> { .map(|path| (apred, path)) }) else { supposely_unreachable_fatal!(s, "ImplExprPredNotFound"; { - self, nested, predicates + self, nested, predicates, trait_ref }) }; + use rustc_middle::ty::ToPolyTraitRef; if apred.is_extra_self_predicate { if !path.is_empty() { supposely_unreachable_fatal!(s[apred.span], "SelfWithNonEmptyPath"; { self, apred, path }); } - ImplExprAtom::SelfImpl.with_args(vec![]) + ImplExprAtom::SelfImpl.with_args(vec![], trait_ref) } else { let clause_id: u64 = clause_id_of_predicate(apred.predicate); + let r#trait = apred + .predicate + .to_opt_poly_trait_pred() + .s_unwrap(s) + .to_poly_trait_ref() + .sinto(s); ImplExprAtom::LocalBound { clause_id, + r#trait, path: path.sinto(s), } - .with_args(impl_exprs(s, &nested)) + .with_args(impl_exprs(s, &nested), trait_ref) + } + } + // Happens when we use a function pointer as an object implementing + // a trait like `FnMut` + ImplSource::FnPointer(rustc_trait_selection::traits::ImplSourceFnPointerData { + fn_ty, + nested, + }) => ImplExprAtom::FnPointer { + fn_ty: fn_ty.sinto(s), + } + .with_args(impl_exprs(s, &nested), trait_ref), + ImplSource::Closure(rustc_trait_selection::traits::ImplSourceClosureData { + closure_def_id, + substs, + nested, + }) => { + let substs = substs.as_closure(); + let parent_substs = substs.parent_substs().sinto(s); + let signature = Box::new(substs.sig().sinto(s)); + ImplExprAtom::Closure { + closure_def_id: closure_def_id.sinto(s), + parent_substs, + signature, } + .with_args(impl_exprs(s, &nested), trait_ref) } ImplSource::Object(data) => ImplExprAtom::Dyn { r#trait: data.upcast_trait_ref.skip_binder().sinto(s), } - .with_args(impl_exprs(s, &data.nested)), + .with_args(impl_exprs(s, &data.nested), trait_ref), ImplSource::Builtin(x) => ImplExprAtom::Builtin { r#trait: self.skip_binder().sinto(s), } - .with_args(impl_exprs(s, &x.nested)), + .with_args(impl_exprs(s, &x.nested), trait_ref), x => ImplExprAtom::Todo(format!( "ImplExprAtom::Todo(see https://github.com/hacspec/hax/issues/381) {:#?}\n\n{:#?}", x, self )) - .into(), + .with_args(vec![], trait_ref), } } } @@ -331,8 +410,7 @@ pub mod copy_paste_from_rustc { use rustc_middle::ty::{self, TyCtxt}; use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt; use rustc_trait_selection::traits::{ - ImplSource, Obligation, ObligationCause, SelectionContext, TraitEngine, TraitEngineExt, - Unimplemented, + Obligation, ObligationCause, SelectionContext, TraitEngine, TraitEngineExt, Unimplemented, }; /// Attempts to resolve an obligation to an `ImplSource`. The result is From 181375fcf479d1a75cbf4b6ca435449dc19480d7 Mon Sep 17 00:00:00 2001 From: Lucas Franceschino Date: Wed, 31 Jan 2024 16:44:16 +0100 Subject: [PATCH 3/4] feat(frontend/traits): propagate improvements in engine --- engine/lib/ast.ml | 3 +++ engine/lib/import_thir.ml | 8 +++++--- engine/lib/subtype.ml | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/engine/lib/ast.ml b/engine/lib/ast.ml index ff6fd5a8b..6fbbeb78c 100644 --- a/engine/lib/ast.ml +++ b/engine/lib/ast.ml @@ -268,6 +268,9 @@ functor | ImplApp of { impl : impl_expr; args : impl_expr list } | Dyn of trait_ref | Builtin of trait_ref + | FnPointer of ty + (* The `IE` suffix is there because visitors conflicts...... *) + | ClosureIE of todo and trait_ref = { trait : concrete_ident; args : generic_value list } diff --git a/engine/lib/import_thir.ml b/engine/lib/import_thir.ml index 1db362c78..567bc0fee 100644 --- a/engine/lib/import_thir.ml +++ b/engine/lib/import_thir.ml @@ -922,18 +922,18 @@ end) : EXPR = struct let trait = Concrete_ident.of_def_id Impl id in let args = List.map ~f:(c_generic_value span) generics in Concrete { trait; args } - | LocalBound { clause_id; path } -> + | LocalBound { clause_id; path; _ } -> let init = LocalBound { id = clause_id } in let f (impl : impl_expr) (chunk : Thir.impl_expr_path_chunk) = match chunk with - | AssocItem (item, { trait_ref; _ }) -> + | AssocItem { item; predicate = { trait_ref; _ }; _ } -> let trait = c_trait_ref span trait_ref in let kind : Concrete_ident.Kind.t = match item.kind with Const | Fn -> Value | Type -> Type in let item = Concrete_ident.of_def_id kind item.def_id in Projection { impl; trait; item } - | Parent { trait_ref; _ } -> + | Parent { predicate = { trait_ref; _ }; _ } -> let trait = c_trait_ref span trait_ref in Parent { impl; trait } in @@ -941,6 +941,8 @@ end) : EXPR = struct | Dyn { trait } -> Dyn (c_trait_ref span trait) | SelfImpl -> Self | Builtin { trait } -> Builtin (c_trait_ref span trait) + | FnPointer { fn_ty } -> FnPointer (c_ty span fn_ty) + | Closure _ as x -> ClosureIE ([%show: Thir.impl_expr_atom] x) | Todo str -> failwith @@ "impl_expr_atom: Todo " ^ str and c_generic_value (span : Thir.span) (ty : Thir.generic_arg) : generic_value diff --git a/engine/lib/subtype.ml b/engine/lib/subtype.ml index 86dfd3a0c..290754012 100644 --- a/engine/lib/subtype.ml +++ b/engine/lib/subtype.ml @@ -69,6 +69,8 @@ struct } | Dyn tr -> Dyn (dtrait_ref span tr) | Builtin tr -> Builtin (dtrait_ref span tr) + | ClosureIE todo -> ClosureIE todo + | FnPointer ty -> FnPointer (dty span ty) and dgeneric_value (span : span) (generic_value : A.generic_value) : B.generic_value = From 8d18742c994cf3ab6f5583c20585158c02380066 Mon Sep 17 00:00:00 2001 From: Lucas Franceschino Date: Wed, 31 Jan 2024 16:49:42 +0100 Subject: [PATCH 4/4] feat(frontend/traits): propagate improvements in tests --- .../toolchain__traits into-fstar.snap | 23 +++++++++++++++++++ tests/traits/src/lib.rs | 8 +++++++ 2 files changed, 31 insertions(+) diff --git a/test-harness/src/snapshots/toolchain__traits into-fstar.snap b/test-harness/src/snapshots/toolchain__traits into-fstar.snap index ed926974c..aab9040a4 100644 --- a/test-harness/src/snapshots/toolchain__traits into-fstar.snap +++ b/test-harness/src/snapshots/toolchain__traits into-fstar.snap @@ -58,6 +58,29 @@ class t_Foo (v_Self: Type) = { f_method_f:v_Self -> Prims.unit } +let closure_impl_expr + (#v_I: Type) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i0: Core.Marker.t_Sized v_I) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i1: Core.Iter.Traits.Iterator.t_Iterator v_I) + (it: v_I) + : Alloc.Vec.t_Vec Prims.unit Alloc.Alloc.t_Global = + Core.Iter.Traits.Iterator.f_collect (Core.Iter.Traits.Iterator.f_map it (fun x -> x) + <: + Core.Iter.Adapters.Map.t_Map v_I (Prims.unit -> Prims.unit)) + +let closure_impl_expr_fngen + (#v_I #v_F: Type) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i0: Core.Marker.t_Sized v_I) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i1: Core.Marker.t_Sized v_F) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i2: Core.Iter.Traits.Iterator.t_Iterator v_I) + (#[FStar.Tactics.Typeclasses.tcresolve ()] i3: Core.Ops.Function.t_FnMut v_F Prims.unit) + (it: v_I) + (f: v_F) + : Alloc.Vec.t_Vec Prims.unit Alloc.Alloc.t_Global = + Core.Iter.Traits.Iterator.f_collect (Core.Iter.Traits.Iterator.f_map it f + <: + Core.Iter.Adapters.Map.t_Map v_I v_F) + let f (#v_T: Type) (#[FStar.Tactics.Typeclasses.tcresolve ()] i0: Core.Marker.t_Sized v_T) diff --git a/tests/traits/src/lib.rs b/tests/traits/src/lib.rs index 6b993a484..f0238d71c 100644 --- a/tests/traits/src/lib.rs +++ b/tests/traits/src/lib.rs @@ -48,3 +48,11 @@ impl<'a> Struct { x.bar() } } + +pub fn closure_impl_expr>(it: I) -> Vec<()> { + it.map(|x| x).collect() +} + +pub fn closure_impl_expr_fngen, F: FnMut(()) -> ()>(it: I, f: F) -> Vec<()> { + it.map(f).collect() +}