1use hax_frontend_exporter_options::BoundsOptions;
6use itertools::{Either, Itertools};
7use std::collections::{HashMap, hash_map::Entry};
8
9use rustc_hir::def::DefKind;
10use rustc_hir::def_id::DefId;
11use rustc_middle::traits::CodegenObligationError;
12use rustc_middle::ty::{self, *};
13use rustc_trait_selection::traits::ImplSource;
14
15use super::utils::{
16 self, ToPolyTraitRef, erase_and_norm, implied_predicates, normalize_bound_val,
17 required_predicates, self_predicate,
18};
19
/// One step of a path leading from a local starting clause to the clause it
/// ends up proving.
///
/// `Parent` steps are added by `PredicateSearcher::insert_candidate_parents`,
/// `AssocItem` steps by `PredicateSearcher::add_associated_type_refs`.
#[derive(Debug, Clone)]
pub enum PathChunk<'tcx> {
    /// Step into a trait-clause bound of an associated type.
    AssocItem {
        /// The associated item whose bound is followed.
        item: AssocItem,
        /// Generic arguments of the projection through which the item was reached.
        generic_args: GenericArgsRef<'tcx>,
        /// The followed bound, instantiated with `generic_args`.
        predicate: PolyTraitPredicate<'tcx>,
        /// Position of this bound among the item's trait-clause bounds.
        index: usize,
    },
    /// Step from a trait clause to one of the trait clauses it implies.
    Parent {
        /// The implied clause, instantiated for the child clause.
        predicate: PolyTraitPredicate<'tcx>,
        /// Position among the clauses returned by `parents_trait_predicates`.
        index: usize,
    },
}
/// A sequence of steps; an empty path means the starting clause is used directly.
pub type Path<'tcx> = Vec<PathChunk<'tcx>>;
39
/// How a trait reference is proven.
#[derive(Debug, Clone)]
pub enum ImplExprAtom<'tcx> {
    /// A user-written impl (from `ImplSource::UserDefined`).
    Concrete {
        /// The `DefId` of the impl.
        def_id: DefId,
        /// The generic arguments the impl was selected with.
        generics: GenericArgsRef<'tcx>,
    },
    /// A bound clause in scope (origin `BoundPredicateOrigin::Item`), possibly
    /// followed through `path`.
    LocalBound {
        /// The starting bound clause, as a predicate.
        predicate: Predicate<'tcx>,
        /// The starting clause's index among the searcher's bound clauses.
        index: usize,
        /// The trait reference of the starting clause.
        r#trait: PolyTraitRef<'tcx>,
        /// Steps from the starting clause to the proven clause.
        path: Path<'tcx>,
    },
    /// The `Self` clause of the enclosing trait (origin
    /// `BoundPredicateOrigin::SelfPred`), possibly followed through `path`.
    SelfImpl {
        /// The trait reference of the `Self` clause.
        r#trait: PolyTraitRef<'tcx>,
        /// Steps from the `Self` clause to the proven clause.
        path: Path<'tcx>,
    },
    /// Proven by a trait object: produced for `BuiltinImplSource::Object`, and
    /// for `Destruct` on a `dyn` type.
    Dyn,
    /// Any other compiler-builtin impl.
    Builtin {
        /// Trait-specific extra data (`Destruct` carries a `DestructData`).
        trait_data: BuiltinTraitData<'tcx>,
        /// Resolutions of the predicates implied by the trait.
        impl_exprs: Vec<ImplExpr<'tcx>>,
        /// For each associated type whose projection normalizes to something
        /// other than itself: its `DefId`, its normalized type, and the
        /// resolutions of the predicates implied by that associated type.
        types: Vec<(DefId, Ty<'tcx>, Vec<ImplExpr<'tcx>>)>,
    },
    /// Resolution failed; holds the message that was also passed to `warn`.
    Error(String),
}
84
/// Extra information attached to `ImplExprAtom::Builtin`.
#[derive(Debug, Clone)]
pub enum BuiltinTraitData<'tcx> {
    /// The trait is the `Destruct` lang item.
    Destruct(DestructData<'tcx>),
    /// Any other builtin trait.
    Other,
}
95
/// How `Destruct` is satisfied for the self type of a builtin `Destruct` impl.
#[derive(Debug, Clone)]
pub enum DestructData<'tcx> {
    /// No-op destruction. `PredicateSearcher::resolve` uses this for scalars,
    /// `str`, raw pointers, references, fn items/pointers, foreign types,
    /// unsafe binders, `!` and `()`.
    Noop,
    /// The self type is a parameter, alias or bound type and
    /// `BoundsOptions::resolve_destruct` is off, so no clause was looked up.
    Implicit,
    /// Destruction presumably goes through the drop glue of `ty`.
    /// `resolve` uses this for arrays, pattern types, slices, non-empty
    /// tuples, ADTs, closures and coroutine-related types.
    Glue {
        /// The self type being destructed.
        ty: Ty<'tcx>,
    },
}
111
/// A trait reference together with how it is proven.
#[derive(Clone, Debug)]
pub struct ImplExpr<'tcx> {
    /// The trait reference as passed to `PredicateSearcher::resolve` (not normalized).
    pub r#trait: PolyTraitRef<'tcx>,
    /// How the trait reference is proven.
    pub r#impl: ImplExprAtom<'tcx>,
}
119
/// Where a starting clause of the `PredicateSearcher` comes from.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BoundPredicateOrigin {
    /// The self predicate of the enclosing trait (see `initial_self_pred`).
    SelfPred,
    /// The n-th clause passed to `PredicateSearcher::insert_bound_predicates`,
    /// counted from 0 per searcher.
    Item(usize),
}
131
/// A starting clause together with its origin.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct AnnotatedTraitPred<'tcx> {
    /// Where the clause comes from.
    pub origin: BoundPredicateOrigin,
    /// The clause as it was inserted (the searcher does not normalize this copy).
    pub clause: PolyTraitPredicate<'tcx>,
}
137
138fn initial_self_pred<'tcx>(
141 tcx: TyCtxt<'tcx>,
142 def_id: rustc_span::def_id::DefId,
143) -> Option<PolyTraitPredicate<'tcx>> {
144 use DefKind::*;
145 let trait_def_id = match tcx.def_kind(def_id) {
146 Trait | TraitAlias => def_id,
147 AssocTy => tcx.parent(def_id),
151 _ => return None,
152 };
153 let self_pred = self_predicate(tcx, trait_def_id).upcast(tcx);
154 Some(self_pred)
155}
156
157fn local_bound_predicates<'tcx>(
160 tcx: TyCtxt<'tcx>,
161 def_id: rustc_span::def_id::DefId,
162 options: BoundsOptions,
163) -> Vec<PolyTraitPredicate<'tcx>> {
164 fn acc_predicates<'tcx>(
165 tcx: TyCtxt<'tcx>,
166 def_id: rustc_span::def_id::DefId,
167 options: BoundsOptions,
168 predicates: &mut Vec<PolyTraitPredicate<'tcx>>,
169 ) {
170 use DefKind::*;
171 match tcx.def_kind(def_id) {
172 AssocTy | AssocFn | AssocConst | Closure | Ctor(..) | Variant => {
174 let parent = tcx.parent(def_id);
175 acc_predicates(tcx, parent, options, predicates);
176 }
177 _ => {}
178 }
179 predicates.extend(
180 required_predicates(tcx, def_id, options)
181 .iter()
182 .map(|(clause, _span)| *clause)
183 .filter_map(|clause| clause.as_trait_clause()),
184 );
185 }
186
187 let mut predicates = vec![];
188 acc_predicates(tcx, def_id, options, &mut predicates);
189 predicates
190}
191
192#[tracing::instrument(level = "trace", skip(tcx))]
193fn parents_trait_predicates<'tcx>(
194 tcx: TyCtxt<'tcx>,
195 pred: PolyTraitPredicate<'tcx>,
196 options: BoundsOptions,
197) -> Vec<PolyTraitPredicate<'tcx>> {
198 let self_trait_ref = pred.to_poly_trait_ref();
199 implied_predicates(tcx, pred.def_id(), options)
200 .iter()
201 .map(|(clause, _span)| *clause)
202 .map(|clause| clause.instantiate_supertrait(tcx, self_trait_ref))
205 .filter_map(|pred| pred.as_trait_clause())
206 .collect()
207}
208
/// A way to prove `pred`: start from the clause in `origin` and follow `path`.
#[derive(Debug, Clone)]
struct Candidate<'tcx> {
    /// Steps from `origin.clause` to `pred`; empty for clauses inserted directly.
    path: Path<'tcx>,
    /// The predicate this candidate proves (normalized by `insert_candidates`).
    pred: PolyTraitPredicate<'tcx>,
    /// The starting clause and where it came from.
    origin: AnnotatedTraitPred<'tcx>,
}
217
218impl<'tcx> Candidate<'tcx> {
219 fn into_impl_expr(self, tcx: TyCtxt<'tcx>) -> ImplExprAtom<'tcx> {
220 let path = self.path;
221 let r#trait = self.origin.clause.to_poly_trait_ref();
222 match self.origin.origin {
223 BoundPredicateOrigin::SelfPred => ImplExprAtom::SelfImpl { r#trait, path },
224 BoundPredicateOrigin::Item(index) => ImplExprAtom::LocalBound {
225 predicate: self.origin.clause.upcast(tcx),
226 index,
227 r#trait,
228 path,
229 },
230 }
231 }
232}
233
/// Resolves trait references to `ImplExpr`s.
///
/// When trait selection answers with a `Param` impl source, the answer is
/// looked up among local candidates. These are the item's self predicate and
/// bound clauses, plus everything reachable from them through implied
/// (parent) clauses and associated-type bounds.
#[derive(Clone)]
pub struct PredicateSearcher<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Environment used to normalize clauses; its `param_env` is also used
    /// for trait selection.
    typing_env: rustc_middle::ty::TypingEnv<'tcx>,
    /// Known candidates keyed by their normalized predicate. When several
    /// candidates prove the same predicate, the first one inserted is kept.
    candidates: HashMap<PolyTraitPredicate<'tcx>, Candidate<'tcx>>,
    /// Options passed to the predicate queries; `resolve_destruct` is also
    /// read by `resolve`.
    options: BoundsOptions,
    /// Number of clauses passed to `insert_bound_predicates` so far, i.e. the
    /// `BoundPredicateOrigin::Item` index the next bound clause receives.
    bound_clause_count: usize,
}
246
247impl<'tcx> PredicateSearcher<'tcx> {
248 pub fn new_for_owner(tcx: TyCtxt<'tcx>, owner_id: DefId, options: BoundsOptions) -> Self {
250 let mut out = Self {
251 tcx,
252 typing_env: TypingEnv {
253 param_env: tcx.param_env(owner_id),
254 typing_mode: TypingMode::PostAnalysis,
255 },
256 candidates: Default::default(),
257 options,
258 bound_clause_count: 0,
259 };
260 out.insert_predicates(
261 initial_self_pred(tcx, owner_id).map(|clause| AnnotatedTraitPred {
262 origin: BoundPredicateOrigin::SelfPred,
263 clause,
264 }),
265 );
266 out.insert_bound_predicates(local_bound_predicates(tcx, owner_id, options));
267 out
268 }
269
270 pub fn insert_bound_predicates(
274 &mut self,
275 clauses: impl IntoIterator<Item = PolyTraitPredicate<'tcx>>,
276 ) {
277 let mut count = usize::MAX;
278 std::mem::swap(&mut count, &mut self.bound_clause_count);
280 self.insert_predicates(clauses.into_iter().map(|clause| {
281 let i = count;
282 count += 1;
283 AnnotatedTraitPred {
284 origin: BoundPredicateOrigin::Item(i),
285 clause,
286 }
287 }));
288 std::mem::swap(&mut count, &mut self.bound_clause_count);
289 }
290
291 pub fn set_param_env(&mut self, param_env: ParamEnv<'tcx>) {
294 self.typing_env.param_env = param_env;
295 }
296
297 fn insert_predicates(&mut self, preds: impl IntoIterator<Item = AnnotatedTraitPred<'tcx>>) {
300 self.insert_candidates(preds.into_iter().map(|clause| Candidate {
301 path: vec![],
302 pred: clause.clause,
303 origin: clause,
304 }))
305 }
306
307 fn insert_candidates(&mut self, candidates: impl IntoIterator<Item = Candidate<'tcx>>) {
310 let tcx = self.tcx;
311 let mut new_candidates = Vec::new();
313 for mut candidate in candidates {
314 candidate.pred = normalize_bound_val(tcx, self.typing_env, candidate.pred);
316 if let Entry::Vacant(entry) = self.candidates.entry(candidate.pred) {
317 entry.insert(candidate.clone());
318 new_candidates.push(candidate);
319 }
320 }
321 if !new_candidates.is_empty() {
322 self.insert_candidate_parents(new_candidates);
324 }
325 }
326
327 fn insert_candidate_parents(&mut self, new_candidates: Vec<Candidate<'tcx>>) {
331 let tcx = self.tcx;
332 let options = self.options;
335 self.insert_candidates(new_candidates.into_iter().flat_map(|candidate| {
336 parents_trait_predicates(tcx, candidate.pred, options)
337 .into_iter()
338 .enumerate()
339 .map(move |(index, parent_pred)| {
340 let mut parent_candidate = Candidate {
341 pred: parent_pred,
342 path: candidate.path.clone(),
343 origin: candidate.origin,
344 };
345 parent_candidate.path.push(PathChunk::Parent {
346 predicate: parent_pred,
347 index,
348 });
349 parent_candidate
350 })
351 }));
352 }
353
354 fn add_associated_type_refs(
356 &mut self,
357 ty: Binder<'tcx, Ty<'tcx>>,
358 warn: &impl Fn(&str),
360 ) -> Result<(), String> {
361 let tcx = self.tcx;
362 let TyKind::Alias(AliasTyKind::Projection, alias_ty) = ty.skip_binder().kind() else {
364 return Ok(());
365 };
366 let trait_ref = ty.rebind(alias_ty.trait_ref(tcx)).upcast(tcx);
367
368 let Some(trait_candidate) = self.resolve_local(trait_ref, warn)? else {
371 return Ok(());
372 };
373
374 let item_bounds = implied_predicates(tcx, alias_ty.def_id, self.options);
376 let item_bounds = item_bounds
377 .iter()
378 .map(|(clause, _span)| *clause)
379 .filter_map(|pred| pred.as_trait_clause())
380 .map(|pred| EarlyBinder::bind(pred).instantiate(tcx, alias_ty.args))
382 .enumerate();
383
384 self.insert_candidates(item_bounds.map(|(index, pred)| {
386 let mut candidate = Candidate {
387 path: trait_candidate.path.clone(),
388 pred,
389 origin: trait_candidate.origin,
390 };
391 candidate.path.push(PathChunk::AssocItem {
392 item: tcx.associated_item(alias_ty.def_id),
393 generic_args: alias_ty.args,
394 predicate: pred,
395 index,
396 });
397 candidate
398 }));
399
400 Ok(())
401 }
402
403 fn resolve_local(
406 &mut self,
407 target: PolyTraitPredicate<'tcx>,
408 warn: &impl Fn(&str),
410 ) -> Result<Option<Candidate<'tcx>>, String> {
411 tracing::trace!("Looking for {target:?}");
412
413 let ret = self.candidates.get(&target).cloned();
415 if ret.is_some() {
416 return Ok(ret);
417 }
418
419 self.add_associated_type_refs(target.self_ty(), warn)?;
421
422 let ret = self.candidates.get(&target).cloned();
423 if ret.is_none() {
424 tracing::trace!(
425 "Couldn't find {target:?} in: [\n{}]",
426 self.candidates
427 .iter()
428 .map(|(_, c)| format!(" - {:?}\n", c.pred))
429 .join("")
430 );
431 }
432 Ok(ret)
433 }
434
435 #[tracing::instrument(level = "trace", skip(self, warn))]
437 pub fn resolve(
438 &mut self,
439 tref: &PolyTraitRef<'tcx>,
440 warn: &impl Fn(&str),
442 ) -> Result<ImplExpr<'tcx>, String> {
443 use rustc_trait_selection::traits::{
444 BuiltinImplSource, ImplSource, ImplSourceUserDefinedData,
445 };
446 let tcx = self.tcx;
447 let destruct_trait = tcx.lang_items().destruct_trait().unwrap();
448
449 let erased_tref = normalize_bound_val(self.tcx, self.typing_env, *tref);
450 let trait_def_id = erased_tref.skip_binder().def_id;
451
452 let error = |msg: String| {
453 warn(&msg);
454 Ok(ImplExpr {
455 r#impl: ImplExprAtom::Error(msg),
456 r#trait: *tref,
457 })
458 };
459
460 let impl_source = shallow_resolve_trait_ref(tcx, self.typing_env.param_env, erased_tref);
461 let atom = match impl_source {
462 Ok(ImplSource::UserDefined(ImplSourceUserDefinedData {
463 impl_def_id,
464 args: generics,
465 ..
466 })) => ImplExprAtom::Concrete {
467 def_id: impl_def_id,
468 generics,
469 },
470 Ok(ImplSource::Param(_)) => {
471 match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
472 Some(candidate) => candidate.into_impl_expr(tcx),
473 None => {
474 let msg = format!(
475 "Could not find a clause for `{tref:?}` in the item parameters"
476 );
477 return error(msg);
478 }
479 }
480 }
481 Ok(ImplSource::Builtin(BuiltinImplSource::Object { .. }, _)) => ImplExprAtom::Dyn,
482 Ok(ImplSource::Builtin(_, _)) => {
483 let impl_exprs = self.resolve_item_implied_predicates(
488 trait_def_id,
489 erased_tref.skip_binder().args,
490 warn,
491 )?;
492 let types = tcx
493 .associated_items(trait_def_id)
494 .in_definition_order()
495 .filter(|assoc| matches!(assoc.kind, AssocKind::Type { .. }))
496 .filter_map(|assoc| {
497 let ty =
498 Ty::new_projection(tcx, assoc.def_id, erased_tref.skip_binder().args);
499 let ty = erase_and_norm(tcx, self.typing_env, ty);
500 if let TyKind::Alias(_, alias_ty) = ty.kind() {
501 if alias_ty.def_id == assoc.def_id {
502 return None;
511 }
512 }
513 let impl_exprs = self
514 .resolve_item_implied_predicates(
515 assoc.def_id,
516 erased_tref.skip_binder().args,
517 warn,
518 )
519 .ok()?;
520 Some((assoc.def_id, ty, impl_exprs))
521 })
522 .collect();
523
524 let trait_data = if erased_tref.skip_binder().def_id == destruct_trait {
525 let ty = erased_tref.skip_binder().args[0].as_type().unwrap();
526 let destruct_data = match ty.kind() {
528 ty::Bool
530 | ty::Char
531 | ty::Int(..)
532 | ty::Uint(..)
533 | ty::Float(..)
534 | ty::Foreign(..)
535 | ty::Str
536 | ty::RawPtr(..)
537 | ty::Ref(..)
538 | ty::FnDef(..)
539 | ty::FnPtr(..)
540 | ty::UnsafeBinder(..)
541 | ty::Never => Either::Left(DestructData::Noop),
542 ty::Tuple(tys) if tys.is_empty() => Either::Left(DestructData::Noop),
543 ty::Array(..)
544 | ty::Pat(..)
545 | ty::Slice(..)
546 | ty::Tuple(..)
547 | ty::Adt(..)
548 | ty::Closure(..)
549 | ty::Coroutine(..)
550 | ty::CoroutineClosure(..)
551 | ty::CoroutineWitness(..) => Either::Left(DestructData::Glue { ty }),
552 ty::Dynamic(..) => Either::Right(ImplExprAtom::Dyn),
555 ty::Param(..) | ty::Alias(..) | ty::Bound(..) => {
556 if self.options.resolve_destruct {
557 match self.resolve_local(erased_tref.upcast(self.tcx), warn)? {
560 Some(candidate) => Either::Right(candidate.into_impl_expr(tcx)),
561 None => {
562 let msg = format!(
563 "Cannot find virtual `Destruct` clause: `{tref:?}`"
564 );
565 return error(msg);
566 }
567 }
568 } else {
569 Either::Left(DestructData::Implicit)
570 }
571 }
572
573 ty::Placeholder(..) | ty::Infer(..) | ty::Error(..) => {
574 let msg = format!(
575 "Cannot resolve clause `{tref:?}` \
576 because of a type error"
577 );
578 return error(msg);
579 }
580 };
581 destruct_data.map_left(BuiltinTraitData::Destruct)
582 } else {
583 Either::Left(BuiltinTraitData::Other)
584 };
585 match trait_data {
586 Either::Left(trait_data) => ImplExprAtom::Builtin {
587 trait_data,
588 impl_exprs,
589 types,
590 },
591 Either::Right(atom) => atom,
592 }
593 }
594 Err(e) => {
595 let msg = format!(
596 "Could not find a clause for `{tref:?}` \
597 in the current context: `{e:?}`"
598 );
599 return error(msg);
600 }
601 };
602
603 Ok(ImplExpr {
604 r#impl: atom,
605 r#trait: *tref,
606 })
607 }
608
609 pub fn resolve_item_required_predicates(
611 &mut self,
612 def_id: DefId,
613 generics: GenericArgsRef<'tcx>,
614 warn: &impl Fn(&str),
616 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
617 let tcx = self.tcx;
618 self.resolve_predicates(
619 generics,
620 required_predicates(tcx, def_id, self.options),
621 warn,
622 )
623 }
624
625 pub fn resolve_item_implied_predicates(
627 &mut self,
628 def_id: DefId,
629 generics: GenericArgsRef<'tcx>,
630 warn: &impl Fn(&str),
632 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
633 let tcx = self.tcx;
634 self.resolve_predicates(
635 generics,
636 implied_predicates(tcx, def_id, self.options),
637 warn,
638 )
639 }
640
641 pub fn resolve_predicates(
644 &mut self,
645 generics: GenericArgsRef<'tcx>,
646 predicates: utils::Predicates<'tcx>,
647 warn: &impl Fn(&str),
649 ) -> Result<Vec<ImplExpr<'tcx>>, String> {
650 let tcx = self.tcx;
651 predicates
652 .iter()
653 .map(|(clause, _span)| *clause)
654 .filter_map(|clause| clause.as_trait_clause())
655 .map(|trait_pred| trait_pred.map_bound(|p| p.trait_ref))
656 .map(|trait_ref| EarlyBinder::bind(trait_ref).instantiate(tcx, generics))
658 .map(|trait_ref| self.resolve(&trait_ref, warn))
660 .collect()
661 }
662}
663
664pub fn shallow_resolve_trait_ref<'tcx>(
672 tcx: TyCtxt<'tcx>,
673 param_env: ParamEnv<'tcx>,
674 trait_ref: PolyTraitRef<'tcx>,
675) -> Result<ImplSource<'tcx, ()>, CodegenObligationError> {
676 use rustc_infer::infer::TyCtxtInferExt;
677 use rustc_middle::traits::CodegenObligationError;
678 use rustc_middle::ty::TypeVisitableExt;
679 use rustc_trait_selection::traits::{
680 Obligation, ObligationCause, ObligationCtxt, SelectionContext, SelectionError,
681 };
682 let infcx = tcx
685 .infer_ctxt()
686 .ignoring_regions()
687 .build(TypingMode::PostAnalysis);
688 let mut selcx = SelectionContext::new(&infcx);
689
690 let obligation_cause = ObligationCause::dummy();
691 let obligation = Obligation::new(tcx, obligation_cause, param_env, trait_ref);
692
693 let selection = match selcx.poly_select(&obligation) {
694 Ok(Some(selection)) => selection,
695 Ok(None) => return Err(CodegenObligationError::Ambiguity),
696 Err(SelectionError::Unimplemented) => return Err(CodegenObligationError::Unimplemented),
697 Err(_) => return Err(CodegenObligationError::Ambiguity),
698 };
699
700 let ocx = ObligationCtxt::new(&infcx);
705 let impl_source = selection.map(|obligation| {
706 ocx.register_obligation(obligation.clone());
707 ()
708 });
709
710 let errors = ocx.evaluate_obligations_error_on_ambiguity();
711 if !errors.is_empty() {
712 return Err(CodegenObligationError::Ambiguity);
713 }
714
715 let impl_source = infcx.resolve_vars_if_possible(impl_source);
716 let impl_source = tcx.erase_and_anonymize_regions(impl_source);
717
718 if impl_source.has_infer() {
719 return Err(CodegenObligationError::Ambiguity);
721 }
722
723 Ok(impl_source)
724}