Skip to content

Commit f43a701

Browse files
committed
dbg
1 parent 93ab989 commit f43a701

3 files changed

Lines changed: 173 additions & 322 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 106 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -420,17 +420,63 @@ private module Input3 implements InputSig3 {
420420
override AstNode getInitializer() { result = LetStmt.super.getInitializer() }
421421
}
422422

423-
class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment;
423+
abstract class Parameter extends Rust::ParamBase {
424+
abstract AstNode getPattern();
425+
426+
abstract TypeMention getType();
427+
}
428+
429+
private class SelfParamParameter extends Parameter, SelfParam {
430+
override AstNode getPattern() { result = this }
431+
432+
override TypeMention getType() { result = getSelfParamTypeMention(this) }
433+
}
434+
435+
private class ParamParameter extends Parameter, Param {
436+
override AstNode getPattern() { result = this.getPat() }
437+
438+
override TypeMention getType() { result = this.getTypeRepr() }
439+
}
440+
441+
final class Callable extends Rust::Callable {
442+
TypeParameter getTypeParameter(TypeParameterPosition ppos) {
443+
result = this.(FunctionDeclaration).getTypeParameter(ppos)
444+
}
424445

425-
final class Callable extends FunctionCallMatchingInput::Declaration {
426446
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) {
427447
result =
428448
tp.(TypeParamTypeParameter).getTypeParam().getAdditionalTypeBound(this, _).getTypeRepr()
429449
}
430450

451+
Parameter getParameter(int i) {
452+
i = 0 and
453+
result = this.getSelfParam()
454+
or
455+
exists(int pos | result = this.getParam(pos) |
456+
if this instanceof Method then i = pos + 1 else i = pos
457+
)
458+
}
459+
460+
TypeMention getReturnType() { result = getReturnTypeMention(this) }
461+
431462
AstNode getBody() { result = super.getBody() }
432463
}
433464

465+
Type getCallableReturnType(Callable c, TypePath path) {
466+
if c.(Function).isAsync()
467+
then
468+
path.isEmpty() and
469+
result = getFutureTraitType()
470+
or
471+
exists(TypePath suffix |
472+
result = getReturnTypeMention(c).getTypeAt(suffix) and
473+
path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
474+
)
475+
else result = getReturnTypeMention(c).getTypeAt(path)
476+
}
477+
478+
class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment;
479+
434480
class Call extends FunctionCallMatchingInput::Access {
435481
Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) }
436482

@@ -444,23 +490,47 @@ private module Input3 implements InputSig3 {
444490

445491
bindingset[derefChainBorrow]
446492
Type inferCallArgumentType(Call call, string derefChainBorrow, int pos, TypePath path) {
493+
exists(FunctionPosition fpos |
494+
pos = fpos.asPosition() and
495+
result = call.(AssocFunctionResolution::OperationAssocFunctionCall).getTypeAt(fpos, path)
496+
)
497+
or
498+
not call instanceof Operation and
447499
result =
448500
call.(FunctionCallMatchingInput::Access).getInferredArgumentType(derefChainBorrow, pos, path)
449501
}
450502

451-
Type inferCallReturnType(Call call, TypePath path) {
503+
Type inferCallType(Call call, TypePath path) {
452504
exists(TypePath path0 |
453-
result = M3::inferCallReturnTypeDefault(call, _, path0) and
505+
result = M3::inferCallTypeDefault(call, _, path0) and
454506
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
455507
// the implicit deref
456-
if call instanceof IndexExpr then path0.isCons(getRefTypeParameter(_), path) else path = path0
508+
if call instanceof IndexExpr or call instanceof DerefExpr
509+
then path0.isCons(getRefTypeParameter(_), path)
510+
else path = path0
511+
)
512+
}
513+
514+
Type inferCallTypeContextual(Call call, TypePath path) {
515+
exists(TypePath path0 |
516+
result = inferType(call, path0) and
517+
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
518+
// the implicit deref
519+
if call instanceof IndexExpr or call instanceof DerefExpr
520+
then path0.isCons(getRefTypeParameter(_), path)
521+
else path = path0
457522
)
458523
}
459524

460525
Type inferCallArgumentTypeContextual(AstNode n, TypePath path) {
461-
exists(FunctionCallMatchingInput::Access call, FunctionPosition pos |
462-
result = inferCallArgumentTypeContextual(call, pos.asPosition(), n, _, _, path) and
463-
not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos)
526+
exists(FunctionCallMatchingInput::Access call, FunctionPosition pos, TypePath path0 |
527+
result = inferCallArgumentTypeContextual(call, pos.asPosition(), n, _, _, path0) and
528+
not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) and
529+
if call.(AssocFunctionResolution::OperationAssocFunctionCall).implicitBorrowAt(pos, _)
530+
then
531+
// adjust for implicit borrow
532+
path0.isCons(getRefTypeParameter(_), path)
533+
else path = path0
464534
)
465535
or
466536
exists(FunctionCallMatchingInput::Access a |
@@ -537,60 +607,6 @@ private module Input3 implements InputSig3 {
537607
)
538608
}
539609

540-
class Operator extends Callable {
541-
private Method getSelfOrImpl() {
542-
result = this
543-
or
544-
this.implements(result)
545-
}
546-
547-
pragma[nomagic]
548-
private predicate borrowsAt(int pos) {
549-
exists(TraitItemNode t, string path, string method |
550-
this.getSelfOrImpl() = t.getAssocItem(method) and
551-
path = t.getCanonicalPath(_) and
552-
exists(int borrows | OperationImpl::isOverloaded(_, _, path, method, borrows) |
553-
pos = 0 and borrows >= 1
554-
or
555-
pos = 1 and
556-
borrows >= 2
557-
)
558-
)
559-
}
560-
561-
pragma[nomagic]
562-
private predicate derefsReturn() { this.getSelfOrImpl() = any(DerefTrait t).getDerefFunction() }
563-
564-
Type getParameterType(int pos, TypePath path) {
565-
exists(TypePath path0 | result = super.getParameterType(pos, path0) |
566-
if this.borrowsAt(pos) then path0.isCons(getRefTypeParameter(_), path) else path0 = path
567-
)
568-
}
569-
570-
Type getReturnType(TypePath path) {
571-
exists(TypePath path0 | result = super.getReturnType(path0) |
572-
if this.derefsReturn() then path0.isCons(getRefTypeParameter(_), path) else path0 = path
573-
)
574-
}
575-
}
576-
577-
class Operation extends AssocFunctionResolution::OperationAssocFunctionCall {
578-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }
579-
580-
AstNode getOperand(int i) {
581-
exists(FunctionPosition pos |
582-
result = this.getNodeAt(pos) and
583-
i = pos.asPosition()
584-
)
585-
}
586-
587-
Operator getTarget() {
588-
exists(ImplOrTraitItemNode i |
589-
result = this.resolveCallTarget(i, _, _, _) // mutual recursion
590-
)
591-
}
592-
}
593-
594610
final class Constructor = ConstructorImpl;
595611

596612
abstract private class ConstructorImpl extends Addressable {
@@ -722,33 +738,10 @@ private module Input3 implements InputSig3 {
722738
override AstNode getArgument(int i) { none() }
723739
}
724740

725-
abstract class Parameter extends Rust::ParamBase {
726-
abstract AstNode getPattern();
727-
728-
abstract TypeMention getType();
729-
}
730-
731-
private class SelfParamParameter extends Parameter, SelfParam {
732-
override AstNode getPattern() { result = this }
733-
734-
override TypeMention getType() {
735-
result = super.getTypeRepr() or
736-
result = this.(ShorthandSelfParameterMention)
737-
}
738-
}
739-
740-
private class ParamParameter extends Parameter, Param {
741-
override AstNode getPattern() { result = this.getPat() }
742-
743-
override TypeMention getType() { result = this.getTypeRepr() }
744-
}
745-
746-
class Closure extends Rust::ClosureExpr {
747-
Parameter getParameter(int i) { result = super.getParam(i) }
748-
749-
AstNode getBody() { result = super.getBody() }
750-
751-
TypeMention getReturnType() { result = this.getRetType().getTypeRepr() }
741+
class Closure extends Expr, Callable instanceof Rust::ClosureExpr {
742+
// Parameter getParameter(int i) { result = super.getParam(i) }
743+
// AstNode getBody() { result = super.getBody() }
744+
// TypeMention getReturnType() { result = this.getRetType().getTypeRepr() }
752745
}
753746

754747
class ClosureParameterPseudoType = T::ClosureParameterPseudoType;
@@ -975,8 +968,6 @@ private module Input3 implements InputSig3 {
975968
result = inferDeconstructionPatType(n, path)
976969
or
977970
result = inferUnknownType(n, path)
978-
// or
979-
// result = inferParamPatType(n, path)
980971
}
981972
}
982973

@@ -1060,26 +1051,24 @@ private class FunctionDeclaration extends Function {
10601051
i.asSome().getAnAssocItem() = this
10611052
)
10621053
}
1063-
1064-
pragma[nomagic]
1065-
Type getParameterType(int j, TypePath path) {
1066-
exists(FunctionPosition fpos | j = fpos.asPosition() |
1067-
result = fpos.getTypeMention(this).getTypeAt(path)
1068-
)
1069-
}
1070-
1071-
Type getReturnType(TypePath path) {
1072-
if this.isAsync()
1073-
then
1074-
path.isEmpty() and
1075-
result = getFutureTraitType()
1076-
or
1077-
exists(TypePath suffix |
1078-
result = getReturnTypeMention(this).getTypeAt(suffix) and
1079-
path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
1080-
)
1081-
else result = getReturnTypeMention(this).getTypeAt(path)
1082-
}
1054+
// pragma[nomagic]
1055+
// Type getParameterType(int j, TypePath path) {
1056+
// exists(FunctionPosition fpos | j = fpos.asPosition() |
1057+
// result = fpos.getTypeMention(this).getTypeAt(path)
1058+
// )
1059+
// }
1060+
// Type getReturnType(TypePath path) {
1061+
// if this.isAsync()
1062+
// then
1063+
// path.isEmpty() and
1064+
// result = getFutureTraitType()
1065+
// or
1066+
// exists(TypePath suffix |
1067+
// result = getReturnTypeMention(this).getTypeAt(suffix) and
1068+
// path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
1069+
// )
1070+
// else result = getReturnTypeMention(this).getTypeAt(path)
1071+
// }
10831072
}
10841073

10851074
private class AssocFunctionDeclaration extends FunctionDeclaration {
@@ -1284,10 +1273,8 @@ private module ContextualTyping {
12841273
* possibly via a constraint on another mentioned type parameter.
12851274
*/
12861275
pragma[nomagic]
1287-
private predicate assocFunctionMentionsTypeParameterAtNonRetPos(
1288-
FunctionDeclaration f, TypeParameter tp
1289-
) {
1290-
tp = f.getParameterType(_, _)
1276+
private predicate assocFunctionMentionsTypeParameterAtNonRetPos(Function f, TypeParameter tp) {
1277+
tp = f.(Input3::Callable).getParameter(_).getType().getTypeAt(_)
12911278
or
12921279
exists(TypeParameter mid |
12931280
assocFunctionMentionsTypeParameterAtNonRetPos(f, mid) and
@@ -1309,7 +1296,7 @@ private module ContextualTyping {
13091296
private predicate assocFunctionReturnContextTypedAt(
13101297
FunctionDeclaration f, TypePath path, TypeParameter tp
13111298
) {
1312-
tp = f.getReturnType(path) and
1299+
tp = Input3::getCallableReturnType(f, path) and
13131300
not assocFunctionMentionsTypeParameterAtNonRetPos(f, tp)
13141301
}
13151302

@@ -2175,7 +2162,7 @@ private module AssocFunctionResolution {
21752162
result = this.getOperand(pos.asPosition())
21762163
}
21772164

2178-
private predicate implicitBorrowAt(FunctionPosition pos, boolean isMutable) {
2165+
predicate implicitBorrowAt(FunctionPosition pos, boolean isMutable) {
21792166
exists(int borrows | this.isOverloaded(_, _, borrows) |
21802167
pos.asPosition() = 0 and
21812168
borrows >= 1 and
@@ -2871,11 +2858,6 @@ private module FunctionCallMatchingInput {
28712858

28722859
private class AssocFunctionCallAccess extends AccessImpl instanceof AssocFunctionResolution::AssocFunctionCall
28732860
{
2874-
AssocFunctionCallAccess() {
2875-
// handled in the `OperationMatchingInput` module
2876-
not this instanceof Operation
2877-
}
2878-
28792861
pragma[nomagic]
28802862
override Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
28812863
result =
@@ -2985,8 +2967,8 @@ private module FunctionCallMatchingInput {
29852967
derefChainBorrow = noDerefChainBorrow() and
29862968
exists(FunctionDeclaration f, TypeParameter tp |
29872969
f = super.resolveCallTargetViaPathResolution() and
2988-
tp = f.getReturnType(path) and
2989-
not tp = f.getParameterType(_, _) and
2970+
tp = Input3::getCallableReturnType(f, path) and
2971+
not tp = f.(Input3::Callable).getParameter(_).getType().getTypeAt(_) and
29902972
// check that no explicit type arguments have been supplied for `tp`
29912973
not exists(TypeArgumentPosition tapos |
29922974
this.hasTypeArgument(tapos) and

0 commit comments

Comments
 (0)