Skip to content

Commit 21d0384

Browse files
committed
wip6
1 parent ba8029f commit 21d0384

2 files changed

Lines changed: 53 additions & 67 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ private module Input3 implements InputSig3 {
434434
}
435435

436436
bindingset[derefChainBorrow]
437-
Type inferCallTypeBottomUp(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) {
437+
Type inferCallTypeAtPos(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) {
438438
result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path)
439439
}
440440

@@ -450,14 +450,14 @@ private module Input3 implements InputSig3 {
450450
)
451451
}
452452

453-
Type inferCallArgumentTypeTopDown(AstNode n, TypePath path) {
453+
Type inferCallArgumentType(AstNode n, TypePath path) {
454454
exists(FunctionCallMatchingInput::Access call, FunctionPosition pos |
455-
result = inferCallArgumentTypeTopDown(call, pos, n, _, _, path) and
455+
result = inferCallArgumentType(call, pos, n, _, _, path) and
456456
not call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos)
457457
)
458458
or
459459
exists(FunctionCallMatchingInput::Access a |
460-
result = inferFunctionCallSelfArgumentTypeTopDown(a, n, DerefChain::nil(), path) and
460+
result = inferFunctionCallSelfArgumentType(a, n, DerefChain::nil(), path) and
461461
if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver()
462462
then not path.isEmpty()
463463
else any()
@@ -623,20 +623,6 @@ private module Input3 implements InputSig3 {
623623
)
624624
}
625625

626-
Type inferTypeTopDown(AstNode n, TypePath path) {
627-
result = inferTypeFromAnnotationTopDown(n, path)
628-
or
629-
result = inferClosureExprBodyTypeTopDown(n, path)
630-
or
631-
exists(FunctionPosition pos | not pos.isReturn() |
632-
result = inferConstructionType(n, pos, path)
633-
or
634-
result = inferOperationType(n, pos, path)
635-
)
636-
or
637-
result = inferFieldExprType(n, path, true)
638-
}
639-
640626
Type inferType(AstNode n, TypePath path) {
641627
result = M3::inferType(n, path)
642628
or
@@ -672,6 +658,20 @@ private module Input3 implements InputSig3 {
672658
or
673659
result = inferUnknownType(n, path)
674660
}
661+
662+
Type inferTypeTopDown(AstNode n, TypePath path) {
663+
result = inferTypeFromAnnotationTopDown(n, path)
664+
or
665+
result = inferClosureExprBodyTypeTopDown(n, path)
666+
or
667+
exists(FunctionPosition pos | not pos.isReturn() |
668+
result = inferConstructionType(n, pos, path)
669+
or
670+
result = inferOperationType(n, pos, path)
671+
)
672+
or
673+
result = inferFieldExprType(n, path, true)
674+
}
675675
}
676676

677677
private module M3 = Make3<Input3>;
@@ -2827,13 +2827,13 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
28272827
}
28282828

28292829
pragma[nomagic]
2830-
private Type inferCallArgumentTypeTopDown(
2830+
private Type inferCallArgumentType(
28312831
FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain,
28322832
BorrowKind borrow, TypePath path
28332833
) {
28342834
exists(string derefChainBorrow |
28352835
FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) and
2836-
result = M3::inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path)
2836+
result = M3::inferCallArgumentType(call, derefChainBorrow, pos, n, path)
28372837
)
28382838
}
28392839

@@ -2845,12 +2845,12 @@ private Type inferCallArgumentTypeTopDown(
28452845
* empty, at which point the inferred type can be applied back to `n`.
28462846
*/
28472847
pragma[nomagic]
2848-
private Type inferFunctionCallSelfArgumentTypeTopDown(
2848+
private Type inferFunctionCallSelfArgumentType(
28492849
FunctionCallMatchingInput::Access call, AstNode n, DerefChain derefChain, TypePath path
28502850
) {
28512851
exists(FunctionPosition pos, BorrowKind borrow, TypePath path0 |
28522852
call.(AssocFunctionResolution::AssocFunctionCall).hasReceiverAtPos(pos) and
2853-
result = inferCallArgumentTypeTopDown(call, pos, n, derefChain, borrow, path0)
2853+
result = inferCallArgumentType(call, pos, n, derefChain, borrow, path0)
28542854
|
28552855
borrow.isNoBorrow() and
28562856
path = path0
@@ -2867,7 +2867,7 @@ private Type inferFunctionCallSelfArgumentTypeTopDown(
28672867
DerefChain derefChain0, Type t0, TypePath path0, DerefImplItemNode impl, Type selfParamType,
28682868
TypePath selfPath
28692869
|
2870-
t0 = inferFunctionCallSelfArgumentTypeTopDown(call, n, derefChain0, path0) and
2870+
t0 = inferFunctionCallSelfArgumentType(call, n, derefChain0, path0) and
28712871
derefChain0.isCons(impl, derefChain) and
28722872
selfParamType = impl.resolveSelfTypeAt(selfPath)
28732873
|

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2251,7 +2251,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22512251
* about the call via the call target, such as the return type.
22522252
*/
22532253
bindingset[ctx]
2254-
default Type inferCallTypeBottomUp(
2254+
default Type inferCallTypeAtPos(
22552255
Call call, CallResolutionContext ctx, TypePosition pos, TypePath path
22562256
) {
22572257
result = inferType(call.getNodeAt(pos), path) and
@@ -2278,9 +2278,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22782278
* type.
22792279
*
22802280
* When no call-context based post-processing is needed, simply implement this
2281-
* predicate as `result = inferCallArgumentTypeTopDown(_, _, _, n, path)`.
2281+
* predicate as `result = inferCallArgumentType(_, _, _, n, path)`.
22822282
*/
2283-
Type inferCallArgumentTypeTopDown(AstNode n, TypePath path);
2283+
Type inferCallArgumentType(AstNode n, TypePath path);
22842284

22852285
/**
22862286
* Holds if `n1` having certain type `t` at `path1` implies that `n2` has
@@ -2475,14 +2475,33 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
24752475
}
24762476

24772477
pragma[nomagic]
2478-
private Type inferTypeFromStepRev(AstNode n, TypePath path) {
2478+
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
2479+
inferType(n, path) instanceof UnknownType
2480+
}
2481+
2482+
pragma[nomagic]
2483+
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
2484+
2485+
pragma[nomagic]
2486+
private Type inferTypeTopDownCand(AstNode n, TypePath path) {
2487+
result = inferTypeTopDown(n, path)
2488+
or
2489+
result = inferCallArgumentType(n, path)
2490+
or
24792491
exists(TypePath path1, AstNode n2, TypePath path2, TypePath suffix |
24802492
result = inferType(n2, path2.appendInverse(suffix)) and
24812493
path = path1.append(suffix) and
24822494
step(n, path1, n2, path2)
24832495
)
24842496
}
24852497

2498+
pragma[nomagic]
2499+
private Type inferTypeTopDownCand(AstNode n, TypePath prefix, TypePath path) {
2500+
result = inferTypeTopDownCand(n, path) and
2501+
hasUnknownType(n) and
2502+
prefix = path.getAPrefix()
2503+
}
2504+
24862505
private Type inferType0(AstNode n, TypePath path) {
24872506
result = Input3::inferType(n, path)
24882507
or
@@ -2497,13 +2516,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
24972516
or
24982517
result = inferTypeFromStep(n, path)
24992518
or
2500-
result = TopDownTyping<inferTypeFromStepRev/2>::inferType(n, path)
2501-
or
25022519
result = inferCallReturnType(n, path)
25032520
or
2504-
result = TopDownTyping<inferCallArgumentTypeTopDown/2>::inferType(n, path)
2505-
or
2506-
result = TopDownTyping<inferTypeTopDown/2>::inferType(n, path)
2521+
// top-down inference
2522+
exists(TypePath prefix |
2523+
result = inferTypeTopDownCand(n, prefix, path) and
2524+
hasUnknownTypeAt(n, prefix)
2525+
)
25072526
}
25082527

25092528
/**
@@ -2581,7 +2600,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25812600
class Access extends CallFinal {
25822601
bindingset[e]
25832602
Type getInferredType(AccessEnvironment e, AccessPosition apos, TypePath path) {
2584-
result = inferCallTypeBottomUp(this, e, apos, path)
2603+
result = inferCallTypeAtPos(this, e, apos, path)
25852604
}
25862605
}
25872606
}
@@ -2602,47 +2621,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
26022621
)
26032622
}
26042623

2605-
Type inferCallArgumentTypeTopDown(
2624+
Type inferCallArgumentType(
26062625
Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path
26072626
) {
26082627
result = inferCallType(call, ctx, pos, n, path) and
26092628
not pos.isReturn() and
26102629
hasUnknownType(n)
26112630
}
26122631

2613-
pragma[nomagic]
2614-
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
2615-
inferType(n, path) instanceof UnknownType
2616-
}
2617-
2618-
pragma[nomagic]
2619-
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
2620-
2621-
private signature Type inferTypeTopDownSig(AstNode n, TypePath path);
2622-
2623-
/**
2624-
* Given a predicate `infer` for inferring the type of an AST node `n`
2625-
* top-down from a context, this module exposes the predicate `inferType`, which
2626-
* restricts type information to only flow top-down into `n` when `n` has an
2627-
* explicit unknown type.
2628-
*/
2629-
private module TopDownTyping<inferTypeTopDownSig/2 infer> {
2630-
pragma[nomagic]
2631-
private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) {
2632-
result = infer(n, path) and
2633-
hasUnknownType(n) and
2634-
prefix = path.getAPrefix()
2635-
}
2636-
2637-
pragma[nomagic]
2638-
Type inferType(AstNode n, TypePath path) {
2639-
exists(TypePath prefix |
2640-
result = inferTypeTopDown(n, prefix, path) and
2641-
hasUnknownTypeAt(n, prefix)
2642-
)
2643-
}
2644-
}
2645-
26462632
/**
26472633
* Gets the inferred root type of `n`, if any.
26482634
*/

0 commit comments

Comments
 (0)