diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala index 5ea541f086e8..88175c9275a6 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala @@ -160,11 +160,6 @@ class CSharpScope(summary: CSharpProgramSummary) callName: String, argTypes: List[String] ): Option[(CSharpMethod, String)] = { - baseTypeFullName.flatMap { tfn => - extensionsInScopeFor(tfn, callName, argTypes).take(2).toList match { - case x :: Nil => Some((x.methods.head, x.name)) - case _ => None - } - } + baseTypeFullName.flatMap(extensionsInScopeFor(_, callName, argTypes).headOption).map(x => (x.methods.head, x.name)) } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala index 0741a9776e71..13f031fef575 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala @@ -1,8 +1,8 @@ package io.joern.csharpsrc2cpg.querying.ast import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.ModifierTypes -import io.shiftleft.codepropertygraph.generated.nodes.Identifier +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, ModifierTypes} +import io.shiftleft.codepropertygraph.generated.nodes.{Call, Identifier} import io.shiftleft.semanticcpg.language.* class ExtensionMethodTests extends CSharpCode2CpgFixture { @@ -55,6 +55,7 @@ class ExtensionMethodTests extends CSharpCode2CpgFixture { case doStuff :: Nil => doStuff.code shouldBe "x.DoStuff()" doStuff.methodFullName shouldBe "Extensions.DoStuff:void(MyClass)" + doStuff.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH case xs => fail(s"Expected single DoStuff call, but got $xs") } } @@ -105,4 +106,136 @@ class ExtensionMethodTests extends CSharpCode2CpgFixture { } } } + + "two same-named extension methods involving explicit sub-types" should { + + "map to the compile-time type (1)" in { + val cpg = code(""" + |var x = new MyConcrete(); + |x.DoStuff(); + | + |abstract class MyAbstract; + |class MyConcrete : MyAbstract; + | + |static class Extensions + |{ + | public static int DoStuff(this MyAbstract myAbstract) => 1; + | public static int DoStuff(this MyConcrete myConcrete) => 2; + |} + |""".stripMargin) + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(MyConcrete)") + } + + "map to the compile-time type (2)" in { + val cpg = code(""" + |MyAbstract x = new MyConcrete(); + |x.DoStuff(); + | + |abstract class MyAbstract; + |class MyConcrete : MyAbstract; + | + |static class Extensions + |{ + | public static int DoStuff(this MyAbstract myAbstract) => 1; + | public static int DoStuff(this MyConcrete myConcrete) => 2; + |} + |""".stripMargin) + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(MyAbstract)") + } + } + + "calling an extension method for `List`" should { + + "resolve correctly if the receiver is of type `List`" in { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) => 1; + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(List)") + } + + "resolve correctly if there's only 1 type-parametric extension for `List`" in { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) => 1; + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").methodFullName.l shouldBe List("Extensions.DoStuff:System.Int32(List)") + } + + // TODO: The two `DoStuff` methods have the same methodFullName. + "resolve correctly if there are 2 possible extensions, one for `List` and another for `List`" ignore { + val cpg = code(""" + |using System.Collections.Generic; + | + |var x = new List(); + |x.DoStuff(); + | + |static class Extensions + |{ + | public static int DoStuff(this List myList) { return 1; } + | public static int DoStuff(this List myList) { return 2; } + |} + |""".stripMargin) + + cpg.call.nameExact("DoStuff").callee.l shouldBe cpg.literal("2").method.l + } + } + + "consecutive unary extension method calls" should { + val cpg = code(""" + |var x = new MyClass(); + |var y = x.Foo().Bar(); + | + |class MyClass {} + |static class Extensions + |{ + | public static MyClass Foo(this MyClass c) => c; + | public static int Bar(this MyClass c) => 1; + |} + |""".stripMargin) + + "have correct properties and arguments" in { + inside(cpg.call.nameExact("Bar").l) { + case bar :: Nil => + bar.code shouldBe "x.Foo().Bar()" + bar.methodFullName shouldBe "Extensions.Bar:System.Int32(MyClass)" + bar.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(bar.argument.sortBy(_.argumentIndex).l) { + case (foo: Call) :: Nil => + foo.code shouldBe "x.Foo()" + foo.methodFullName shouldBe "Extensions.Foo:MyClass(MyClass)" + foo.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + inside(foo.argument.sortBy(_.argumentIndex).l) { + case (x: Identifier) :: Nil => + x.code shouldBe "x" + x.name shouldBe "x" + x.typeFullName shouldBe "MyClass" + case xs => fail(s"Expected identifier argument to Foo, but got $xs") + } + case xs => fail(s"Expected single call argument to Bar, but got $xs") + } + case xs => fail(s"Expected single call to Bar, but got $xs") + } + } + + "have correct properties for the result of the chained call" in { + cpg.assignment.target.isIdentifier.nameExact("y").typeFullName.l shouldBe List("System.Int32") + } + } }