From 723a8e772bbdcd6f6aa6313bb6d3e473a927db13 Mon Sep 17 00:00:00 2001 From: Xavier Pinho Date: Thu, 23 Jan 2025 07:34:39 +0000 Subject: [PATCH] [c#] methodFullName for extension method calls (#5245) --- .../AstForExpressionsCreator.scala | 13 ++- .../datastructures/CSharpScope.scala | 48 ++++++++ .../querying/ast/ExtensionMethodTests.scala | 108 ++++++++++++++++++ 3 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala index 947282524cf5..e0d5b6ae2f1e 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -245,7 +245,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) { val arguments = astForArgumentList(argumentList, baseTypeFullName) val argTypes = arguments.map(getTypeFullNameFromAstNode).toList val methodMetaData = scope.tryResolveMethodInvocation(callName, argTypes, baseTypeFullName) - (receiverAst.headOption, baseTypeFullName, methodMetaData, arguments) + + // If the instance lookup has failed, we try to look for an extension method. + val instanceLookupResult = (receiverAst.headOption, baseTypeFullName, methodMetaData, arguments) + if (methodMetaData.isEmpty) { + scope.tryResolveExtensionMethodInvocation(baseTypeFullName, callName, argTypes) match { + case Some((methodMetaData, methodClassFullName)) => + (receiverAst.headOption, Some(methodClassFullName), Some(methodMetaData), arguments) + case None => instanceLookupResult + } + } else { + instanceLookupResult + } case IdentifierName | MemberBindingExpression => // This is when a call is made directly, which could also be made from a static import val argTypes = astForArgumentList(argumentList).map(getTypeFullNameFromAstNode).toList diff --git a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala index a31f78ee9633..d48eba93419f 100644 --- a/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala +++ b/joern-cli/frontends/csharpsrc2cpg/src/main/scala/io/joern/csharpsrc2cpg/datastructures/CSharpScope.scala @@ -4,8 +4,10 @@ import io.joern.x2cpg.Defines import io.joern.x2cpg.datastructures.{OverloadableScope, Scope, ScopeElement, TypedScope, TypedScopeElement} import io.joern.x2cpg.utils.ListUtils.singleOrNone import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew +import io.joern.x2cpg.utils.ListUtils.singleOrNone import scala.collection.mutable +import scala.reflect.ClassTag class CSharpScope(summary: CSharpProgramSummary) extends Scope[String, DeclarationNew, TypedScopeElement] @@ -118,4 +120,50 @@ class CSharpScope(summary: CSharpProgramSummary) Option(top) } + /** Reduces [[typesInScope]] to contain only those types holding an extension method with the desired signature. + */ + private def extensionsInScopeFor( + extendedType: String, + callName: String, + argTypes: List[String] + ): mutable.Set[CSharpType] = { + typesInScope + .map(t => t.copy(methods = t.methods.filter(matchingExtensionMethod(extendedType, callName, argTypes)))) + .filter(_.methods.nonEmpty) + } + + /** Builds a predicate for matching [[CSharpMethod]] with an ad-hoc description of theirs. + */ + private def matchingExtensionMethod( + thisType: String, + name: String, + argTypes: List[String] + ): CSharpMethod => Boolean = { m => + m.isStatic && m.name == name && m.parameterTypes.map(_._2) == thisType :: argTypes + } + + /** Tries to find an extension method for [[baseTypeFullName]] with the given [[callName]] and [[argTypes]] in the + * types currently in scope. + * + * @param baseTypeFullName + * the extension method's `this` argument. + * @param callName + * the method name + * @param argTypes + * the method's argument types, excluding `this` + * @return + * the method metadata, together with the class name where it can be found + */ + def tryResolveExtensionMethodInvocation( + baseTypeFullName: Option[String], + callName: String, + argTypes: List[String] + ): Option[(CSharpMethod, String)] = { + baseTypeFullName.flatMap { tfn => + extensionsInScopeFor(tfn, callName, argTypes).take(2).toList match { + case x :: Nil => Some((x.methods.head, x.name)) + case _ => None + } + } + } } diff --git a/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala new file mode 100644 index 000000000000..0741a9776e71 --- /dev/null +++ b/joern-cli/frontends/csharpsrc2cpg/src/test/scala/io/joern/csharpsrc2cpg/querying/ast/ExtensionMethodTests.scala @@ -0,0 +1,108 @@ +package io.joern.csharpsrc2cpg.querying.ast + +import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture +import io.shiftleft.codepropertygraph.generated.ModifierTypes +import io.shiftleft.codepropertygraph.generated.nodes.Identifier +import io.shiftleft.semanticcpg.language.* + +class ExtensionMethodTests extends CSharpCode2CpgFixture { + + "nullary extension-method declaration" should { + val cpg = code(""" + |class MyClass {} + |static class Extensions + |{ + | public static void DoStuff(this MyClass myClass) {} + |} + |""".stripMargin) + + "have correct properties" in { + inside(cpg.method.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.fullName shouldBe "Extensions.DoStuff:void(MyClass)" + doStuff.signature shouldBe "void(MyClass)" + doStuff.methodReturn.typeFullName shouldBe "void" + doStuff.modifier.modifierType.toSet shouldBe Set(ModifierTypes.STATIC, ModifierTypes.PUBLIC) + case xs => fail(s"Expected single DoStuff method, but got $xs") + } + } + + "have correct parameters" in { + inside(cpg.method.nameExact("DoStuff").parameter.sortBy(_.index).l) { + case myClass :: Nil => + myClass.typeFullName shouldBe "MyClass" + myClass.code shouldBe "this MyClass myClass" + myClass.name shouldBe "myClass" + case xs => fail(s"Expected single parameter, but got $xs") + } + } + } + + "nullary extension-method call" should { + val cpg = code(""" + |var x = new MyClass(); + |x.DoStuff(); + | + |class MyClass {} + |static class Extensions + |{ + | public static void DoStuff(this MyClass myClass) {} + |} + |""".stripMargin) + + "have correct properties" in { + inside(cpg.call.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.code shouldBe "x.DoStuff()" + doStuff.methodFullName shouldBe "Extensions.DoStuff:void(MyClass)" + case xs => fail(s"Expected single DoStuff call, but got $xs") + } + } + + "have correct arguments" in { + inside(cpg.call.nameExact("DoStuff").argument.sortBy(_.argumentIndex).l) { + case (x: Identifier) :: Nil => + x.argumentIndex shouldBe 0 + x.name shouldBe "x" + x.typeFullName shouldBe "MyClass" + x.code shouldBe "x" + case xs => fail(s"Expected single identifier argument to DoStuff, but got $xs") + } + } + } + + "two same-named extension methods in different namespaces" should { + val cpg = code(""" + |using Version1; + |var x = new MyClass(); + |x.DoStuff(0); + | + |class MyClass {} + |""".stripMargin) + .moreCode(""" + |namespace Version1; + | + |static class Extension1 + |{ + | public static void DoStuff(this MyClass myClass, int z) {} + |} + |""".stripMargin) + .moreCode(""" + |namespace Version2; + | + |static class Extension2 + |{ + | public static void DoStuff(this MyClass myClass, int z) {} + |} + |""".stripMargin) + + "find the correct extension method" in { + inside(cpg.call.nameExact("DoStuff").l) { + case doStuff :: Nil => + doStuff.code shouldBe "x.DoStuff(0)" + doStuff.methodFullName shouldBe "Version1.Extension1.DoStuff:void(MyClass,System.Int32)" + case xs => fail(s"Expected single DoStuff call, but got $xs") + } + } + } +}