diff --git a/csharp/ql/lib/semmle/code/csharp/frameworks/System.qll b/csharp/ql/lib/semmle/code/csharp/frameworks/System.qll index 2681b9437b69..3872ad13dcb9 100644 --- a/csharp/ql/lib/semmle/code/csharp/frameworks/System.qll +++ b/csharp/ql/lib/semmle/code/csharp/frameworks/System.qll @@ -144,7 +144,6 @@ class SystemIComparableInterface extends SystemInterface { result.getDeclaringType() = this and result.hasName("CompareTo") and result.getNumberOfParameters() = 1 and - result.getParameter(0).getType() instanceof ObjectType and result.getReturnType() instanceof IntType } } @@ -263,7 +262,6 @@ class SystemObjectClass extends SystemClass instanceof ObjectType { result.getDeclaringType() = this and result.hasName("Equals") and result.getNumberOfParameters() = 1 and - result.getParameter(0).getType() instanceof ObjectType and result.getReturnType() instanceof BoolType } @@ -273,8 +271,6 @@ class SystemObjectClass extends SystemClass instanceof ObjectType { result.getDeclaringType() = this and result.hasName("Equals") and result.getNumberOfParameters() = 2 and - result.getParameter(0).getType() instanceof ObjectType and - result.getParameter(1).getType() instanceof ObjectType and result.getReturnType() instanceof BoolType } @@ -284,8 +280,6 @@ class SystemObjectClass extends SystemClass instanceof ObjectType { result.getDeclaringType() = this and result.hasName("ReferenceEquals") and result.getNumberOfParameters() = 2 and - result.getParameter(0).getType() instanceof ObjectType and - result.getParameter(1).getType() instanceof ObjectType and result.getReturnType() instanceof BoolType } diff --git a/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.cs b/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.cs index 504c81198355..02b9e6b340a6 100644 --- a/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.cs +++ b/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.cs @@ -24,3 +24,16 @@ struct Equals1Struct { public override bool Equals(object other) => false; } + +#nullable enable + +class NullableEquals1 +{ + public override bool Equals(object? other) => false; +} + +class NullableEquals2 : IEquatable +{ + public bool Equals(NullableEquals2? other) => other != null; + public override bool Equals(object? other) => other is NullableEquals2 n && Equals(n); +} diff --git a/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.expected b/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.expected index b05c8852b2bb..30dafbad3418 100644 --- a/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.expected +++ b/csharp/ql/test/library-tests/frameworks/system/Equals/Equals.expected @@ -5,3 +5,5 @@ | Equals.cs:16:7:16:13 | Equals3 | Equals3.Equals(Equals3) | true | | Equals.cs:21:8:21:21 | NoEqualsStruct | System.ValueType.Equals(object) | false | | Equals.cs:23:8:23:20 | Equals1Struct | Equals1Struct.Equals(object) | true | +| Equals.cs:31:7:31:21 | NullableEquals1 | NullableEquals1.Equals(object) | true | +| Equals.cs:36:7:36:21 | NullableEquals2 | NullableEquals2.Equals(NullableEquals2) | true | diff --git a/csharp/ql/test/query-tests/API Abuse/ClassDoesNotImplementEquals/NullableTest.cs b/csharp/ql/test/query-tests/API Abuse/ClassDoesNotImplementEquals/NullableTest.cs new file mode 100644 index 000000000000..81c96940a909 --- /dev/null +++ b/csharp/ql/test/query-tests/API Abuse/ClassDoesNotImplementEquals/NullableTest.cs @@ -0,0 +1,46 @@ +using System; + +#nullable enable + +namespace Test +{ + class TestClass1 : IEquatable + { + private int field1; + + public bool Equals(TestClass1? param1) + { + return param1 != null && field1 == param1.field1; + } + + public override bool Equals(object? param2) + { + return param2 is TestClass1 tc && Equals(tc); + } + + public override int GetHashCode() + { + return field1; + } + } + + class TestClass2 + { + private string field2; + + public TestClass2(string s) + { + field2 = s; + } + + public override bool Equals(object? param3) + { + return param3 is TestClass2 tc && field2 == tc.field2; + } + + public override int GetHashCode() + { + return field2?.GetHashCode() ?? 0; + } + } +}