1use hax_frontend_exporter_options::BoundsOptions;
6use itertools::Itertools;
7use std::collections::{HashMap, hash_map::Entry};
8
9use rustc_hir::def::DefKind;
10use rustc_hir::def_id::DefId;
11use rustc_middle::traits::CodegenObligationError;
12use rustc_middle::ty::{self, *};
13use rustc_trait_selection::traits::ImplSource;
14
15use crate::{self_predicate, traits::utils::erase_and_norm};
16
17use super::utils::{ToPolyTraitRef, implied_predicates, normalize_bound_val, required_predicates};
18
/// One step of a [`Path`]: how to go from a known trait predicate to another one
/// that it implies.
#[derive(Debug, Clone)]
pub enum PathChunk<'tcx> {
    /// Step through an associated type: from a proof of the trait that declares
    /// `item` to one of the trait bounds implied by `item` itself.
    AssocItem {
        /// The associated item whose implied bounds are used.
        item: AssocItem,
        /// The associated item's own generic arguments (those not belonging to the
        /// trait).
        generic_args: GenericArgsRef<'tcx>,
        /// Proofs of the associated item's required predicates.
        impl_exprs: Vec<ImplExpr<'tcx>>,
        /// The implied trait predicate reached by this step.
        predicate: PolyTraitPredicate<'tcx>,
        /// Index of `predicate` among the item's implied trait predicates.
        index: usize,
    },
    /// Step from a trait predicate to one of its parent predicates (obtained by
    /// `instantiate_supertrait` on the trait's implied predicates).
    Parent {
        /// The parent predicate reached by this step.
        predicate: PolyTraitPredicate<'tcx>,
        /// Index of `predicate` among the parent trait predicates.
        index: usize,
    },
}
/// Sequence of steps leading from a local bound to the predicate it proves.
pub type Path<'tcx> = Vec<PathChunk<'tcx>>;
46
/// How a trait obligation was found to hold (the "impl" part of an [`ImplExpr`]).
#[derive(Debug, Clone)]
pub enum ImplExprAtom<'tcx> {
    /// A user-written impl.
    Concrete {
        /// The impl's `DefId`.
        def_id: DefId,
        /// The generic arguments the impl is instantiated with.
        generics: GenericArgsRef<'tcx>,
        /// Proofs of the impl's required predicates, in order.
        impl_exprs: Vec<ImplExpr<'tcx>>,
    },
    /// A predicate in scope for the current item (or one of its parents), from
    /// which the obligation is reached by following `path`.
    LocalBound {
        /// The local bound the path starts from.
        predicate: Predicate<'tcx>,
        /// Position of that bound in the `BoundPredicateOrigin::Item` numbering.
        index: usize,
        /// Trait ref of the local bound (not of the final proven predicate).
        r#trait: PolyTraitRef<'tcx>,
        /// Steps from the local bound to the proven predicate.
        path: Path<'tcx>,
    },
    /// The enclosing trait's own `Self: Trait` predicate, followed by `path`.
    SelfImpl {
        /// Trait ref of the `Self: Trait` predicate.
        r#trait: PolyTraitRef<'tcx>,
        /// Steps from that predicate to the proven predicate.
        path: Path<'tcx>,
    },
    /// Proven by a trait object (`dyn Trait`).
    Dyn,
    /// A `Drop` obligation for a type that selection found no `Drop` impl for.
    Drop(DropData<'tcx>),
    /// A compiler-builtin impl other than a trait object.
    Builtin {
        /// The trait ref that was resolved.
        r#trait: PolyTraitRef<'tcx>,
        /// Proofs of the trait's implied predicates.
        impl_exprs: Vec<ImplExpr<'tcx>>,
        /// Associated types of the trait paired with their normalized value; types
        /// that normalize back to their own projection are omitted.
        types: Vec<(DefId, Ty<'tcx>)>,
    },
    /// Resolution failed; carries the message that was also passed to `warn`.
    Error(String),
}
97
/// How a `Drop` obligation is discharged when the type has no `Drop` impl.
#[derive(Debug, Clone)]
pub enum DropData<'tcx> {
    /// Nothing to resolve. Produced for scalars, `str`, references, raw pointers,
    /// function items/pointers, foreign types, unsafe binders and `!`.
    Noop,
    /// A generic/projection/bound type while `options.resolve_drop` is disabled:
    /// no clause is looked up.
    Implicit,
    /// Drop glue for a compound type.
    Glue {
        /// The type being dropped.
        ty: Ty<'tcx>,
        /// `Drop` resolutions for the component types: the element type of
        /// arrays/slices/pattern types, the fields of tuples, or the type
        /// arguments of ADTs, closures and coroutines.
        impl_exprs: Vec<ImplExpr<'tcx>>,
    },
}
118
/// A resolved trait obligation: which trait ref holds, and how.
#[derive(Clone, Debug)]
pub struct ImplExpr<'tcx> {
    /// The trait ref that was resolved, exactly as passed to `resolve`
    /// (before normalization).
    pub r#trait: PolyTraitRef<'tcx>,
    /// How `r#trait` was found to hold.
    pub r#impl: ImplExprAtom<'tcx>,
}
126
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
/// Where a locally-available trait predicate comes from.
pub enum BoundPredicateOrigin {
    /// The `Self: Trait` predicate of an enclosing trait or trait alias.
    SelfPred,
    /// The n-th trait clause among the required predicates of the item and its
    /// parents (numbered parents first; non-trait clauses are not counted).
    Item(usize),
}
138
/// A trait predicate available in the current item, tagged with its origin.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct AnnotatedTraitPred<'tcx> {
    /// Where the predicate comes from.
    pub origin: BoundPredicateOrigin,
    /// The predicate itself.
    pub clause: PolyTraitPredicate<'tcx>,
}
144
145fn initial_search_predicates<'tcx>(
149 tcx: TyCtxt<'tcx>,
150 def_id: rustc_span::def_id::DefId,
151 options: BoundsOptions,
152) -> Vec<AnnotatedTraitPred<'tcx>> {
153 fn acc_predicates<'tcx>(
154 tcx: TyCtxt<'tcx>,
155 def_id: rustc_span::def_id::DefId,
156 options: BoundsOptions,
157 predicates: &mut Vec<AnnotatedTraitPred<'tcx>>,
158 pred_id: &mut usize,
159 ) {
160 let next_item_origin = |pred_id: &mut usize| {
161 let origin = BoundPredicateOrigin::Item(*pred_id);
162 *pred_id += 1;
163 origin
164 };
165 use DefKind::*;
166 match tcx.def_kind(def_id) {
167 AssocTy | AssocFn | AssocConst | Closure | Ctor(..) | Variant => {
169 let parent = tcx.parent(def_id);
170 acc_predicates(tcx, parent, options, predicates, pred_id);
171 }
172 Trait | TraitAlias => {
173 let self_pred = self_predicate(tcx, def_id).upcast(tcx);
174 predicates.push(AnnotatedTraitPred {
175 origin: BoundPredicateOrigin::SelfPred,
176 clause: self_pred,
177 })
178 }
179 _ => {}
180 }
181 predicates.extend(
182 required_predicates(tcx, def_id, options)
183 .predicates
184 .iter()
185 .map(|(clause, _span)| *clause)
186 .filter_map(|clause| {
187 clause.as_trait_clause().map(|clause| AnnotatedTraitPred {
188 origin: next_item_origin(pred_id),
189 clause,
190 })
191 }),
192 );
193 }
194
195 let mut predicates = vec![];
196 acc_predicates(tcx, def_id, options, &mut predicates, &mut 0);
197 predicates
198}
199
200#[tracing::instrument(level = "trace", skip(tcx))]
201fn parents_trait_predicates<'tcx>(
202 tcx: TyCtxt<'tcx>,
203 pred: PolyTraitPredicate<'tcx>,
204 options: BoundsOptions,
205) -> Vec<PolyTraitPredicate<'tcx>> {
206 let self_trait_ref = pred.to_poly_trait_ref();
207 implied_predicates(tcx, pred.def_id(), options)
208 .predicates
209 .iter()
210 .map(|(clause, _span)| *clause)
211 .map(|clause| clause.instantiate_supertrait(tcx, self_trait_ref))
214 .filter_map(|pred| pred.as_trait_clause())
215 .collect()
216}
217
/// A way to prove `pred` from a locally-available predicate.
#[derive(Debug, Clone)]
struct Candidate<'tcx> {
    /// Steps leading from `origin` to `pred`.
    path: Path<'tcx>,
    /// The predicate this candidate proves (normalized when inserted by
    /// `PredicateSearcher::extend`).
    pred: PolyTraitPredicate<'tcx>,
    /// The local predicate the path starts from.
    origin: AnnotatedTraitPred<'tcx>,
}
226
227impl<'tcx> Candidate<'tcx> {
228 fn into_impl_expr(self, tcx: TyCtxt<'tcx>) -> ImplExprAtom<'tcx> {
229 let path = self.path;
230 let r#trait = self.origin.clause.to_poly_trait_ref();
231 match self.origin.origin {
232 BoundPredicateOrigin::SelfPred => ImplExprAtom::SelfImpl { r#trait, path },
233 BoundPredicateOrigin::Item(index) => ImplExprAtom::LocalBound {
234 predicate: self.origin.clause.upcast(tcx),
235 index,
236 r#trait,
237 path,
238 },
239 }
240 }
241}
242
/// Finds proofs of trait obligations in the context of a single item.
///
/// `resolve` asks rustc's trait selection which kind of proof applies. It consults
/// the locally-known `candidates` when the obligation is proven by a parameter
/// bound, and also for `Drop` on generic types when `options.resolve_drop` is set.
pub struct PredicateSearcher<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Environment used for normalization and selection. `new_for_owner` sets it
    /// to the owner's `param_env` in `TypingMode::PostAnalysis`.
    typing_env: rustc_middle::ty::TypingEnv<'tcx>,
    /// Known local proofs keyed by their (normalized) predicate. Only the first
    /// candidate found for a given predicate is kept.
    candidates: HashMap<PolyTraitPredicate<'tcx>, Candidate<'tcx>>,
    /// Options forwarded to `required_predicates`/`implied_predicates`; also
    /// controls whether `Drop` bounds on generic types are resolved.
    options: BoundsOptions,
}
251
252impl<'tcx> PredicateSearcher<'tcx> {
253 pub fn new_for_owner(tcx: TyCtxt<'tcx>, owner_id: DefId, options: BoundsOptions) -> Self {
255 let mut out = Self {
256 tcx,
257 typing_env: TypingEnv {
258 param_env: tcx.param_env(owner_id),
259 typing_mode: TypingMode::PostAnalysis,
260 },
261 candidates: Default::default(),
262 options,
263 };
264 out.extend(
265 initial_search_predicates(tcx, owner_id, options)
266 .into_iter()
267 .map(|clause| Candidate {
268 path: vec![],
269 pred: clause.clause,
270 origin: clause,
271 }),
272 );
273 out
274 }
275
276 fn extend(&mut self, candidates: impl IntoIterator<Item = Candidate<'tcx>>) {
279 let tcx = self.tcx;
280 let mut new_candidates = Vec::new();
282 for mut candidate in candidates {
283 candidate.pred = normalize_bound_val(tcx, self.typing_env, candidate.pred);
285 if let Entry::Vacant(entry) = self.candidates.entry(candidate.pred) {
286 entry.insert(candidate.clone());
287 new_candidates.push(candidate);
288 }
289 }
290 if !new_candidates.is_empty() {
291 self.extend_parents(new_candidates);
292 }
293 }
294
295 fn extend_parents(&mut self, new_candidates: Vec<Candidate<'tcx>>) {
299 let tcx = self.tcx;
300 let options = self.options;
303 self.extend(new_candidates.into_iter().flat_map(|candidate| {
304 parents_trait_predicates(tcx, candidate.pred, options)
305 .into_iter()
306 .enumerate()
307 .map(move |(index, parent_pred)| {
308 let mut parent_candidate = Candidate {
309 pred: parent_pred,
310 path: candidate.path.clone(),
311 origin: candidate.origin,
312 };
313 parent_candidate.path.push(PathChunk::Parent {
314 predicate: parent_pred,
315 index,
316 });
317 parent_candidate
318 })
319 }));
320 }
321
322 fn add_associated_type_refs(
324 &mut self,
325 ty: Binder<'tcx, Ty<'tcx>>,
326 warn: &impl Fn(&str),
328 ) -> Result<(), String> {
329 let tcx = self.tcx;
330 let TyKind::Alias(AliasTyKind::Projection, alias_ty) = ty.skip_binder().kind() else {
332 return Ok(());
333 };
334 let (trait_ref, item_args) = alias_ty.trait_ref_and_own_args(tcx);
335 let item_args = tcx.mk_args(item_args);
336 let trait_ref = ty.rebind(trait_ref).upcast(tcx);
337
338 let Some(trait_candidate) = self.resolve_local(trait_ref, warn)? else {
341 return Ok(());
342 };
343
344 let item_bounds = implied_predicates(tcx, alias_ty.def_id, self.options)
346 .predicates
347 .iter()
348 .map(|(clause, _span)| *clause)
349 .filter_map(|pred| pred.as_trait_clause())
350 .map(|pred| EarlyBinder::bind(pred).instantiate(tcx, alias_ty.args))
352 .enumerate();
353
354 let nested_impl_exprs =
356 self.resolve_item_required_predicates(alias_ty.def_id, alias_ty.args, warn)?;
357
358 self.extend(item_bounds.map(|(index, pred)| {
360 let mut candidate = Candidate {
361 path: trait_candidate.path.clone(),
362 pred,
363 origin: trait_candidate.origin,
364 };
365 candidate.path.push(PathChunk::AssocItem {
366 item: tcx.associated_item(alias_ty.def_id),
367 generic_args: item_args,
368 impl_exprs: nested_impl_exprs.clone(),
369 predicate: pred,
370 index,
371 });
372 candidate
373 }));
374
375 Ok(())
376 }
377
378 fn resolve_local(
381 &mut self,
382 target: PolyTraitPredicate<'tcx>,
383 warn: &impl Fn(&str),
385 ) -> Result<Option<Candidate<'tcx>>, String> {
386 tracing::trace!("Looking for {target:?}");
387
388 let ret = self.candidates.get(&target).cloned();
390 if ret.is_some() {
391 return Ok(ret);
392 }
393
394 self.add_associated_type_refs(target.self_ty(), warn)?;
396
397 let ret = self.candidates.get(&target).cloned();
398 if ret.is_none() {
399 tracing::trace!(
400 "Couldn't find {target:?} in: [\n{}]",
401 self.candidates
402 .iter()
403 .map(|(_, c)| format!(" - {:?}\n", c.pred))
404 .join("")
405 );
406 }
407 Ok(ret)
408 }
409
410 #[tracing::instrument(level = "trace", skip(self, warn))]
412 pub fn resolve(
413 &mut self,
414 tref: &PolyTraitRef<'tcx>,
415 warn: &impl Fn(&str),
417 ) -> Result<ImplExpr<'tcx>, String> {
418 use rustc_trait_selection::traits::{
419 BuiltinImplSource, ImplSource, ImplSourceUserDefinedData,
420 };
421 let tcx = self.tcx;
422 let drop_trait = tcx.lang_items().drop_trait().unwrap();
423
424 let erased_tref = normalize_bound_val(self.tcx, self.typing_env, *tref);
425 let trait_def_id = erased_tref.skip_binder().def_id;
426
427 let impl_source = shallow_resolve_trait_ref(tcx, self.typing_env.param_env, erased_tref);
428 let atom = match impl_source {
429 Ok(ImplSource::UserDefined(ImplSourceUserDefinedData {
430 impl_def_id,
431 args: generics,
432 ..
433 })) => {
434 let impl_exprs =
436 self.resolve_item_required_predicates(impl_def_id, generics, warn)?;
437 ImplExprAtom::Concrete {
438 def_id: impl_def_id,
439 generics,
440 impl_exprs,
441 }
442 }
443 Ok(ImplSource::Param(_)) => {
444 match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
445 Some(candidate) => candidate.into_impl_expr(tcx),
446 None => {
447 let msg = format!(
448 "Could not find a clause for `{tref:?}` in the item parameters"
449 );
450 warn(&msg);
451 ImplExprAtom::Error(msg)
452 }
453 }
454 }
455 Ok(ImplSource::Builtin(BuiltinImplSource::Object { .. }, _)) => ImplExprAtom::Dyn,
456 Ok(ImplSource::Builtin(_, _)) => {
457 let impl_exprs = self.resolve_item_implied_predicates(
462 trait_def_id,
463 erased_tref.skip_binder().args,
464 warn,
465 )?;
466 let types = tcx
467 .associated_items(trait_def_id)
468 .in_definition_order()
469 .filter(|assoc| matches!(assoc.kind, AssocKind::Type { .. }))
470 .filter_map(|assoc| {
471 let ty =
472 Ty::new_projection(tcx, assoc.def_id, erased_tref.skip_binder().args);
473 let ty = erase_and_norm(tcx, self.typing_env, ty);
474 if let TyKind::Alias(_, alias_ty) = ty.kind() {
475 if alias_ty.def_id == assoc.def_id {
476 return None;
485 }
486 }
487 Some((assoc.def_id, ty))
488 })
489 .collect();
490 ImplExprAtom::Builtin {
491 r#trait: *tref,
492 impl_exprs,
493 types,
494 }
495 }
496 Err(CodegenObligationError::Unimplemented)
498 if erased_tref.skip_binder().def_id == drop_trait =>
499 {
500 let mut resolve_drop = |ty: Ty<'tcx>| {
504 let tref = ty::Binder::dummy(ty::TraitRef::new(tcx, drop_trait, [ty]));
505 self.resolve(&tref, warn)
506 };
507 let find_drop_impl = |ty: Ty<'tcx>| {
508 let mut dtor = None;
509 tcx.for_each_relevant_impl(drop_trait, ty, |impl_did| {
510 dtor = Some(impl_did);
511 });
512 dtor
513 };
514 let ty = erased_tref.skip_binder().args[0].as_type().unwrap();
516 match ty.kind() {
518 ty::Bool
520 | ty::Char
521 | ty::Int(..)
522 | ty::Uint(..)
523 | ty::Float(..)
524 | ty::Foreign(..)
525 | ty::Str
526 | ty::RawPtr(..)
527 | ty::Ref(..)
528 | ty::FnDef(..)
529 | ty::FnPtr(..)
530 | ty::UnsafeBinder(..)
531 | ty::Never => ImplExprAtom::Drop(DropData::Noop),
532 ty::Array(inner_ty, _) | ty::Pat(inner_ty, _) | ty::Slice(inner_ty) => {
533 ImplExprAtom::Drop(DropData::Glue {
534 ty,
535 impl_exprs: vec![resolve_drop(*inner_ty)?],
536 })
537 }
538 ty::Tuple(tys) => ImplExprAtom::Drop(DropData::Glue {
539 ty,
540 impl_exprs: tys.iter().map(resolve_drop).try_collect()?,
541 }),
542 ty::Adt(..) if let Some(_) = find_drop_impl(ty) => {
543 let msg = format!("Cannot resolve clause `{tref:?}`");
546 warn(&msg);
547 ImplExprAtom::Error(msg)
548 }
549 ty::Adt(_, args)
550 | ty::Closure(_, args)
551 | ty::Coroutine(_, args)
552 | ty::CoroutineClosure(_, args)
553 | ty::CoroutineWitness(_, args) => ImplExprAtom::Drop(DropData::Glue {
554 ty,
555 impl_exprs: args
556 .iter()
557 .filter_map(|arg| arg.as_type())
558 .map(resolve_drop)
559 .try_collect()?,
560 }),
561 ty::Dynamic(..) => ImplExprAtom::Dyn,
564 ty::Param(..) | ty::Alias(..) | ty::Bound(..) => {
565 if self.options.resolve_drop {
566 match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
569 Some(candidate) => candidate.into_impl_expr(tcx),
570 None => {
571 let msg =
572 format!("Cannot find virtual `Drop` clause: `{tref:?}`");
573 warn(&msg);
574 ImplExprAtom::Error(msg)
575 }
576 }
577 } else {
578 ImplExprAtom::Drop(DropData::Implicit)
579 }
580 }
581
582 ty::Placeholder(..) | ty::Infer(..) | ty::Error(..) => {
583 let msg = format!(
584 "Cannot resolve clause `{tref:?}` \
585 because of a type error"
586 );
587 warn(&msg);
588 ImplExprAtom::Error(msg)
589 }
590 }
591 }
592 Err(e) => {
593 let msg = format!(
594 "Could not find a clause for `{tref:?}` \
595 in the current context: `{e:?}`"
596 );
597 warn(&msg);
598 ImplExprAtom::Error(msg)
599 }
600 };
601
602 Ok(ImplExpr {
603 r#impl: atom,
604 r#trait: *tref,
605 })
606 }
607
608 pub fn resolve_item_required_predicates(
610 &mut self,
611 def_id: DefId,
612 generics: GenericArgsRef<'tcx>,
613 warn: &impl Fn(&str),
615 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
616 let tcx = self.tcx;
617 self.resolve_predicates(
618 generics,
619 required_predicates(tcx, def_id, self.options),
620 warn,
621 )
622 }
623
624 pub fn resolve_item_implied_predicates(
626 &mut self,
627 def_id: DefId,
628 generics: GenericArgsRef<'tcx>,
629 warn: &impl Fn(&str),
631 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
632 let tcx = self.tcx;
633 self.resolve_predicates(
634 generics,
635 implied_predicates(tcx, def_id, self.options),
636 warn,
637 )
638 }
639
640 pub fn resolve_predicates(
643 &mut self,
644 generics: GenericArgsRef<'tcx>,
645 predicates: GenericPredicates<'tcx>,
646 warn: &impl Fn(&str),
648 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
649 let tcx = self.tcx;
650 predicates
651 .predicates
652 .iter()
653 .map(|(clause, _span)| *clause)
654 .filter_map(|clause| clause.as_trait_clause())
655 .map(|trait_pred| trait_pred.map_bound(|p| p.trait_ref))
656 .map(|trait_ref| EarlyBinder::bind(trait_ref).instantiate(tcx, generics))
658 .map(|trait_ref| self.resolve(&trait_ref, warn))
660 .collect()
661 }
662}
663
664pub fn shallow_resolve_trait_ref<'tcx>(
672 tcx: TyCtxt<'tcx>,
673 param_env: ParamEnv<'tcx>,
674 trait_ref: PolyTraitRef<'tcx>,
675) -> Result<ImplSource<'tcx, ()>, CodegenObligationError> {
676 use rustc_infer::infer::TyCtxtInferExt;
677 use rustc_middle::traits::CodegenObligationError;
678 use rustc_middle::ty::TypeVisitableExt;
679 use rustc_trait_selection::traits::{
680 Obligation, ObligationCause, ObligationCtxt, SelectionContext, SelectionError,
681 };
682 let infcx = tcx
685 .infer_ctxt()
686 .ignoring_regions()
687 .build(TypingMode::PostAnalysis);
688 let mut selcx = SelectionContext::new(&infcx);
689
690 let obligation_cause = ObligationCause::dummy();
691 let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref);
692
693 let selection = match selcx.poly_select(&obligation) {
694 Ok(Some(selection)) => selection,
695 Ok(None) => return Err(CodegenObligationError::Ambiguity),
696 Err(SelectionError::Unimplemented) => return Err(CodegenObligationError::Unimplemented),
697 Err(_) => return Err(CodegenObligationError::Ambiguity),
698 };
699
700 let ocx = ObligationCtxt::new(&infcx);
705 let impl_source = selection.map(|obligation| {
706 ocx.register_obligation(obligation.clone());
707 ()
708 });
709
710 let errors = ocx.select_all_or_error();
711 if !errors.is_empty() {
712 return Err(CodegenObligationError::Ambiguity);
713 }
714
715 let impl_source = infcx.resolve_vars_if_possible(impl_source);
716 let impl_source = tcx.erase_regions(impl_source);
717
718 if impl_source.has_infer() {
719 return Err(CodegenObligationError::Ambiguity);
721 }
722
723 Ok(impl_source)
724}