Skip to content

Commit cca458c

Browse files
committed
Rust: Address review comments and handle ! types in type inference
1 parent 5697a7e commit cca458c

File tree

6 files changed

+114
-32
lines changed

6 files changed

+114
-32
lines changed

‎rust/ql/lib/codeql/rust/internal/TypeInference.qll‎

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,14 @@ import M2
232232
module Consistency {
233233
import M2::Consistency
234234

235+
private Type inferCertainTypeAdj(AstNode n, TypePath path) {
236+
result = CertainTypeInference::inferCertainType(n, path) and
237+
not result = TNeverType()
238+
}
239+
235240
predicate nonUniqueCertainType(AstNode n, TypePath path, Type t) {
236-
strictcount(CertainTypeInference::inferCertainType(n, path)) > 1 and
237-
t = CertainTypeInference::inferCertainType(n, path) and
241+
strictcount(inferCertainTypeAdj(n, path)) > 1 and
242+
t = inferCertainTypeAdj(n, path) and
238243
// Suppress the inconsistency if `n` is a self parameter and the type
239244
// mention for the self type has multiple types for a path.
240245
not exists(ImplItemNode impl, TypePath selfTypePath |
@@ -291,6 +296,17 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
291296
result = n.(ShorthandSelfParameterMention).resolveTypeAt(path)
292297
}
293298

299+
/**
300+
* Holds if `me` is a call to the `panic!` macro.
301+
*
302+
* `panic!` needs special treatment, because it expands to a block expression
303+
* that looks like it should have type `()` instead of the correct `!` type.
304+
*/
305+
pragma[nomagic]
306+
private predicate isPanicMacroCall(MacroExpr me) {
307+
me.getMacroCall().resolveMacro().(MacroRules).getName().getText() = "panic"
308+
}
309+
294310
/** Module for inferring certain type information. */
295311
module CertainTypeInference {
296312
pragma[nomagic]
@@ -443,6 +459,14 @@ module CertainTypeInference {
443459
or
444460
result = inferCastExprType(n, path)
445461
or
462+
exprHasUnitType(n) and
463+
path.isEmpty() and
464+
result instanceof UnitType
465+
or
466+
isPanicMacroCall(n) and
467+
path.isEmpty() and
468+
result instanceof NeverType
469+
or
446470
infersCertainTypeAt(n, path, result.getATypeParameter())
447471
}
448472

@@ -579,7 +603,8 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
579603
n2 = be.getRhs()
580604
)
581605
or
582-
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion()
606+
n1 = n2.(MacroExpr).getMacroCall().getMacroCallExpansion() and
607+
not isPanicMacroCall(n2)
583608
or
584609
n1 = n2.(MacroPat).getMacroCall().getMacroCallExpansion()
585610
or
@@ -931,14 +956,17 @@ private predicate functionInfoBlanketLike(
931956
*/
932957
bindingset[path, type]
933958
private predicate isComplexRootStripped(TypePath path, Type type) {
934-
path.isEmpty() and
935-
not validSelfType(type)
936-
or
937-
exists(TypeParameter tp |
938-
complexSelfRoot(_, tp) and
939-
path = TypePath::singleton(tp) and
940-
exists(type)
941-
)
959+
(
960+
path.isEmpty() and
961+
not validSelfType(type)
962+
or
963+
exists(TypeParameter tp |
964+
complexSelfRoot(_, tp) and
965+
path = TypePath::singleton(tp) and
966+
exists(type)
967+
)
968+
) and
969+
type != TNeverType()
942970
}
943971

944972
/**
@@ -1540,7 +1568,8 @@ private module MethodResolution {
15401568
MethodCall getMethodCall() { result = mc_ }
15411569

15421570
Type getTypeAt(TypePath path) {
1543-
result = mc_.getACandidateReceiverTypeAtSubstituteLookupTraits(derefChain, borrow, path)
1571+
result = mc_.getACandidateReceiverTypeAtSubstituteLookupTraits(derefChain, borrow, path) and
1572+
not result = TNeverType()
15441573
}
15451574

15461575
pragma[nomagic]
@@ -2810,7 +2839,8 @@ private predicate isReturnExprCfgAncestor(AstNode n) {
28102839
pragma[nomagic]
28112840
predicate isUnitBlockExpr(BlockExpr be) {
28122841
not be.getStmtList().hasTailExpr() and
2813-
not isReturnExprCfgAncestor(be)
2842+
not isReturnExprCfgAncestor(be) and
2843+
not be.hasLabel()
28142844
}
28152845

28162846
pragma[nomagic]
@@ -2831,6 +2861,15 @@ private Type inferBlockExprType(BlockExpr be, TypePath path) {
28312861
)
28322862
}
28332863

2864+
pragma[nomagic]
2865+
private predicate exprHasUnitType(Expr e) {
2866+
e = any(IfExpr ie | not ie.hasElse())
2867+
or
2868+
e instanceof WhileExpr
2869+
or
2870+
e instanceof ForExpr
2871+
}
2872+
28342873
final private class AwaitTarget extends Expr {
28352874
AwaitTarget() { this = any(AwaitExpr ae).getExpr() }
28362875

‎rust/ql/lib/codeql/rust/internal/TypeMention.qll‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,11 @@ TypeMention getSelfParamTypeMention(SelfParam self) {
444444
}
445445

446446
/**
447-
* An element used to represent the implicit `()` return type of function.
447+
* An element used to represent the implicit `()` return type of a function.
448+
*
449+
* Since the implicit type does not appear in the AST, we (somewhat arbitrarily)
450+
* choose the name of the function as a type mention. This works because there
451+
* is a one-to-one correspondence between a function and its name.
448452
*/
449453
class ShorthandReturnTypeMention extends TypeMention instanceof Name {
450454
private Function f;

‎rust/ql/lib/codeql/rust/internal/typeinference/BlanketImplementation.qll‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ module SatisfiesBlanketConstraint<
9090

9191
Location getLocation() { result = at.getLocation() }
9292

93-
Type getTypeAt(TypePath path) { result = at.getTypeAt(blanketPath.appendInverse(path)) }
93+
Type getTypeAt(TypePath path) {
94+
result = at.getTypeAt(blanketPath.appendInverse(path)) and
95+
not result = TNeverType()
96+
}
9497

9598
string toString() { result = at.toString() + " [blanket at " + blanketPath.toString() + "]" }
9699
}

‎rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll‎

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class FunctionPosition extends TFunctionPosition {
4646
result = f.getParam(this.asPosition()).getTypeRepr()
4747
or
4848
this.isReturn() and
49-
result = f.getRetType().getTypeRepr()
49+
result = getReturnTypeMention(f)
5050
}
5151

5252
string toString() {
@@ -263,7 +263,10 @@ module ArgIsInstantiationOf<
263263
final private class ArgFinal = Arg;
264264

265265
private class ArgSubst extends ArgFinal {
266-
Type getTypeAt(TypePath path) { result = substituteLookupTraits(super.getTypeAt(path)) }
266+
Type getTypeAt(TypePath path) {
267+
result = substituteLookupTraits(super.getTypeAt(path)) and
268+
not result = TNeverType()
269+
}
267270
}
268271

269272
private module IsInstantiationOfInput implements
@@ -368,10 +371,10 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
368371
CallAndPos cp, Input::Call call, FunctionPosition pos, int rnk, Function f,
369372
TypeAbstraction abs, AssocFunctionType constraint
370373
) {
371-
cp = MkCallAndPos(call, pos) and
374+
cp = MkCallAndPos(call, pragma[only_bind_into](pos)) and
372375
call.hasTargetCand(abs, f) and
373-
toCheckRanked(abs, f, pos, rnk) and
374-
Input::toCheck(abs, f, pos, constraint)
376+
toCheckRanked(abs, f, pragma[only_bind_into](pos), rnk) and
377+
Input::toCheck(abs, f, pragma[only_bind_into](pos), constraint)
375378
}
376379

377380
pragma[nomagic]

‎rust/ql/test/library-tests/type-inference/main.rs‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,7 +2841,7 @@ mod block_types {
28412841
#[rustfmt::skip]
28422842
fn f1(cond: bool) -> i32 {
28432843
// Block that evaluates to unit
2844-
let a = { // $ MISSING: type=a:()
2844+
let a = { // $ type=a:()
28452845
if cond {
28462846
return 12;
28472847
}
@@ -2852,7 +2852,7 @@ mod block_types {
28522852
#[rustfmt::skip]
28532853
fn f2() -> i32 {
28542854
// Block that does not evaluate to unit
2855-
let b = 'label: { // $ MISSING: b:i32 SPURIOUS: certainType=b:()
2855+
let b = 'label: { // $ MISSING: b:i32
28562856
break 'label 12;
28572857
};
28582858
println!("b: {:?}", b);

0 commit comments

Comments
 (0)