Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[c#] methodFullName for extension method calls #5245

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,18 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
val arguments = astForArgumentList(argumentList, baseTypeFullName)
val argTypes = arguments.map(getTypeFullNameFromAstNode).toList
val methodMetaData = scope.tryResolveMethodInvocation(callName, argTypes, baseTypeFullName)
(receiverAst.headOption, baseTypeFullName, methodMetaData, arguments)

// If the instance lookup has failed, we try to look for an extension method.
val instanceLookupResult = (receiverAst.headOption, baseTypeFullName, methodMetaData, arguments)
if (methodMetaData.isEmpty) {
scope.tryResolveExtensionMethodInvocation(baseTypeFullName, callName, argTypes) match {
case Some((methodMetaData, methodClassFullName)) =>
(receiverAst.headOption, Some(methodClassFullName), Some(methodMetaData), arguments)
case None => instanceLookupResult
}
} else {
instanceLookupResult
}
case IdentifierName | MemberBindingExpression =>
// This is when a call is made directly, which could also be made from a static import
val argTypes = astForArgumentList(argumentList).map(getTypeFullNameFromAstNode).toList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import io.joern.x2cpg.Defines
import io.joern.x2cpg.datastructures.{OverloadableScope, Scope, ScopeElement, TypedScope, TypedScopeElement}
import io.joern.x2cpg.utils.ListUtils.singleOrNone
import io.shiftleft.codepropertygraph.generated.nodes.DeclarationNew
import io.joern.x2cpg.utils.ListUtils.singleOrNone

import scala.collection.mutable
import scala.reflect.ClassTag

class CSharpScope(summary: CSharpProgramSummary)
extends Scope[String, DeclarationNew, TypedScopeElement]
Expand Down Expand Up @@ -118,4 +120,50 @@ class CSharpScope(summary: CSharpProgramSummary)
Option(top)
}

/** Reduces [[typesInScope]] to contain only those types holding an extension method with the desired signature.
*/
private def extensionsInScopeFor(
extendedType: String,
callName: String,
argTypes: List[String]
): mutable.Set[CSharpType] = {
typesInScope
.map(t => t.copy(methods = t.methods.filter(matchingExtensionMethod(extendedType, callName, argTypes))))
.filter(_.methods.nonEmpty)
}

/** Builds a predicate for matching [[CSharpMethod]] with an ad-hoc description of theirs.
*/
private def matchingExtensionMethod(
thisType: String,
name: String,
argTypes: List[String]
): CSharpMethod => Boolean = { m =>
m.isStatic && m.name == name && m.parameterTypes.map(_._2) == thisType :: argTypes
}

/** Tries to find an extension method for [[baseTypeFullName]] with the given [[callName]] and [[argTypes]] in the
* types currently in scope.
*
* @param baseTypeFullName
* the extension method's `this` argument.
* @param callName
* the method name
* @param argTypes
* the method's argument types, excluding `this`
* @return
* the method metadata, together with the class name where it can be found
*/
def tryResolveExtensionMethodInvocation(
baseTypeFullName: Option[String],
callName: String,
argTypes: List[String]
): Option[(CSharpMethod, String)] = {
baseTypeFullName.flatMap { tfn =>
extensionsInScopeFor(tfn, callName, argTypes).take(2).toList match {
case x :: Nil => Some((x.methods.head, x.name))
case _ => None
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package io.joern.csharpsrc2cpg.querying.ast

import io.joern.csharpsrc2cpg.testfixtures.CSharpCode2CpgFixture
import io.shiftleft.codepropertygraph.generated.ModifierTypes
import io.shiftleft.codepropertygraph.generated.nodes.Identifier
import io.shiftleft.semanticcpg.language.*

class ExtensionMethodTests extends CSharpCode2CpgFixture {

"nullary extension-method declaration" should {
val cpg = code("""
|class MyClass {}
|static class Extensions
|{
| public static void DoStuff(this MyClass myClass) {}
|}
|""".stripMargin)

"have correct properties" in {
inside(cpg.method.nameExact("DoStuff").l) {
case doStuff :: Nil =>
doStuff.fullName shouldBe "Extensions.DoStuff:void(MyClass)"
doStuff.signature shouldBe "void(MyClass)"
doStuff.methodReturn.typeFullName shouldBe "void"
doStuff.modifier.modifierType.toSet shouldBe Set(ModifierTypes.STATIC, ModifierTypes.PUBLIC)
case xs => fail(s"Expected single DoStuff method, but got $xs")
}
}

"have correct parameters" in {
inside(cpg.method.nameExact("DoStuff").parameter.sortBy(_.index).l) {
case myClass :: Nil =>
myClass.typeFullName shouldBe "MyClass"
myClass.code shouldBe "this MyClass myClass"
myClass.name shouldBe "myClass"
case xs => fail(s"Expected single parameter, but got $xs")
}
}
}

"nullary extension-method call" should {
val cpg = code("""
|var x = new MyClass();
|x.DoStuff();
|
|class MyClass {}
|static class Extensions
|{
| public static void DoStuff(this MyClass myClass) {}
|}
|""".stripMargin)

"have correct properties" in {
inside(cpg.call.nameExact("DoStuff").l) {
case doStuff :: Nil =>
doStuff.code shouldBe "x.DoStuff()"
doStuff.methodFullName shouldBe "Extensions.DoStuff:void(MyClass)"
case xs => fail(s"Expected single DoStuff call, but got $xs")
}
}

"have correct arguments" in {
inside(cpg.call.nameExact("DoStuff").argument.sortBy(_.argumentIndex).l) {
case (x: Identifier) :: Nil =>
x.argumentIndex shouldBe 0
x.name shouldBe "x"
x.typeFullName shouldBe "MyClass"
x.code shouldBe "x"
case xs => fail(s"Expected single identifier argument to DoStuff, but got $xs")
}
}
}

"two same-named extension methods in different namespaces" should {
val cpg = code("""
|using Version1;
|var x = new MyClass();
|x.DoStuff(0);
|
|class MyClass {}
|""".stripMargin)
.moreCode("""
|namespace Version1;
|
|static class Extension1
|{
| public static void DoStuff(this MyClass myClass, int z) {}
|}
|""".stripMargin)
.moreCode("""
|namespace Version2;
|
|static class Extension2
|{
| public static void DoStuff(this MyClass myClass, int z) {}
|}
|""".stripMargin)

"find the correct extension method" in {
inside(cpg.call.nameExact("DoStuff").l) {
case doStuff :: Nil =>
doStuff.code shouldBe "x.DoStuff(0)"
doStuff.methodFullName shouldBe "Version1.Extension1.DoStuff:void(MyClass,System.Int32)"
case xs => fail(s"Expected single DoStuff call, but got $xs")
}
}
}
}
Loading