Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DXIL] Adding support to RootSignatureFlags in obj2yaml #122396

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

joaosaffran
Copy link
Contributor

@joaosaffran joaosaffran commented Jan 10, 2025

This PR adds:

  • RootSignatureFlags extraction from DXContainer using obj2yaml

@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-objectyaml

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

Changes

This PR adds:

  • Root signature 1.0 definition for RootSignatureFlags
  • Root Signature Generation to DX Container
  • Root Signature RootSignatureFlags extraction from LLVM
  • Root Signature generation to DXIL IR
  • RootSignatureFlags Validation
  • RootSignatureFlags extraction from DXContainer using obj2yaml

Patch is 25.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122396.diff

14 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILMetadataAnalysis.h (+3)
  • (added) llvm/include/llvm/Analysis/DXILRootSignature.h (+88)
  • (modified) llvm/include/llvm/BinaryFormat/DXContainerConstants.def (+1)
  • (modified) llvm/include/llvm/Object/DXContainer.h (+8)
  • (modified) llvm/include/llvm/ObjectYAML/DXContainerYAML.h (+14)
  • (modified) llvm/lib/Analysis/CMakeLists.txt (+1)
  • (modified) llvm/lib/Analysis/DXILMetadataAnalysis.cpp (+17)
  • (added) llvm/lib/Analysis/DXILRootSignature.cpp (+110)
  • (modified) llvm/lib/Object/DXContainer.cpp (+15)
  • (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+7)
  • (modified) llvm/lib/ObjectYAML/DXContainerYAML.cpp (+68)
  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+23)
  • (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignatures/FlagsElement.ll (+27)
  • (modified) llvm/tools/obj2yaml/dxcontainer2yaml.cpp (+22)
diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c61..89c5bffcdbb954 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -10,10 +10,12 @@
 #define LLVM_ANALYSIS_DXILMETADATA_H
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/VersionTuple.h"
 #include "llvm/TargetParser/Triple.h"
+#include <optional>
 
 namespace llvm {
 
@@ -37,6 +39,7 @@ struct ModuleMetadataInfo {
   Triple::EnvironmentType ShaderProfile{Triple::UnknownEnvironment};
   VersionTuple ValidatorVersion{};
   SmallVector<EntryProperties> EntryPropertyVec{};
+  std::optional<root_signature::VersionedRootSignatureDesc> RootSignatureDesc;
   void print(raw_ostream &OS) const;
 };
 
diff --git a/llvm/include/llvm/Analysis/DXILRootSignature.h b/llvm/include/llvm/Analysis/DXILRootSignature.h
new file mode 100644
index 00000000000000..cb3d6192f4404d
--- /dev/null
+++ b/llvm/include/llvm/Analysis/DXILRootSignature.h
@@ -0,0 +1,88 @@
+//===- DXILRootSignature.h - DXIL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with DXIL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_DIRECTX_HLSLROOTSIGNATURE_H
+#define LLVM_DIRECTX_HLSLROOTSIGNATURE_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
+namespace llvm {
+namespace dxil {
+namespace root_signature {
+
+enum class RootSignatureElementKind {
+  None = 0,
+  RootFlags = 1,
+  RootConstants = 2,
+  RootDescriptor = 3,
+  DescriptorTable = 4,
+  StaticSampler = 5
+};
+
+enum class RootSignatureVersion {
+  Version_1 = 1,
+  Version_1_0 = 1,
+  Version_1_1 = 2,
+  Version_1_2 = 3
+};
+
+enum RootSignatureFlags : uint32_t {
+  None = 0,
+  AllowInputAssemblerInputLayout = 0x1,
+  DenyVertexShaderRootAccess = 0x2,
+  DenyHullShaderRootAccess = 0x4,
+  DenyDomainShaderRootAccess = 0x8,
+  DenyGeometryShaderRootAccess = 0x10,
+  DenyPixelShaderRootAccess = 0x20,
+  AllowStreamOutput = 0x40,
+  LocalRootSignature = 0x80,
+  DenyAmplificationShaderRootAccess = 0x100,
+  DenyMeshShaderRootAccess = 0x200,
+  CBVSRVUAVHeapDirectlyIndexed = 0x400,
+  SamplerHeapDirectlyIndexed = 0x800,
+  AllowLowTierReservedHwCbLimit = 0x80000000,
+  ValidFlags = 0x80000fff
+};
+
+struct DxilRootSignatureDesc1_0 {
+  RootSignatureFlags Flags;
+};
+
+struct VersionedRootSignatureDesc {
+  RootSignatureVersion Version;
+  union {
+    DxilRootSignatureDesc1_0 Desc_1_0;
+  };
+
+  bool isPopulated();
+
+  void swapBytes();
+};
+
+class MetadataParser {
+public:
+  NamedMDNode *Root;
+  MetadataParser(NamedMDNode *Root) : Root(Root) {}
+
+  bool Parse(RootSignatureVersion Version, VersionedRootSignatureDesc *Desc);
+
+private:
+  bool ParseRootFlags(MDNode *RootFlagRoot, VersionedRootSignatureDesc *Desc);
+  bool ParseRootSignatureElement(MDNode *Element,
+                                 VersionedRootSignatureDesc *Desc);
+};
+} // namespace root_signature
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_DIRECTX_HLSLROOTSIGNATURE_H
diff --git a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
index 1aacbb2f65b27f..38b69228cd3975 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
+++ b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
@@ -4,6 +4,7 @@ CONTAINER_PART(DXIL)
 CONTAINER_PART(SFI0)
 CONTAINER_PART(HASH)
 CONTAINER_PART(PSV0)
+CONTAINER_PART(RTS0)
 CONTAINER_PART(ISG1)
 CONTAINER_PART(OSG1)
 CONTAINER_PART(PSG1)
diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h
index 19c83ba6c6e85d..9a6aa8224eddf4 100644
--- a/llvm/include/llvm/Object/DXContainer.h
+++ b/llvm/include/llvm/Object/DXContainer.h
@@ -17,6 +17,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBufferRef.h"
@@ -287,6 +288,7 @@ class DXContainer {
   std::optional<uint64_t> ShaderFeatureFlags;
   std::optional<dxbc::ShaderHash> Hash;
   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
+  std::optional<dxil::root_signature::VersionedRootSignatureDesc> RootSignature;
   DirectX::Signature InputSignature;
   DirectX::Signature OutputSignature;
   DirectX::Signature PatchConstantSignature;
@@ -296,6 +298,7 @@ class DXContainer {
   Error parseDXILHeader(StringRef Part);
   Error parseShaderFeatureFlags(StringRef Part);
   Error parseHash(StringRef Part);
+  Error parseRootSignature(StringRef Part);
   Error parsePSVInfo(StringRef Part);
   Error parseSignature(StringRef Part, DirectX::Signature &Array);
   friend class PartIterator;
@@ -382,6 +385,11 @@ class DXContainer {
 
   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
 
+  std::optional<dxil::root_signature::VersionedRootSignatureDesc>
+  getRootSignature() const {
+    return RootSignature;
+  }
+
   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
     return PSVInfo;
   };
diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index 66ad057ab0e30f..e9da51f61c0a2b 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -16,6 +16,7 @@
 #define LLVM_OBJECTYAML_DXCONTAINERYAML_H
 
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/ObjectYAML/YAML.h"
 #include "llvm/Support/YAMLTraits.h"
@@ -149,6 +150,13 @@ struct Signature {
   llvm::SmallVector<SignatureParameter> Parameters;
 };
 
+struct RootSignature {
+  RootSignature() = default;
+
+  dxil::root_signature::RootSignatureVersion Version;
+  dxil::root_signature::RootSignatureFlags Flags;
+};
+
 struct Part {
   Part() = default;
   Part(std::string N, uint32_t S) : Name(N), Size(S) {}
@@ -159,6 +167,7 @@ struct Part {
   std::optional<ShaderHash> Hash;
   std::optional<PSVInfo> Info;
   std::optional<DXContainerYAML::Signature> Signature;
+  std::optional<DXContainerYAML::RootSignature> RootSignature;
 };
 
 struct Object {
@@ -241,6 +250,11 @@ template <> struct MappingTraits<DXContainerYAML::Signature> {
   static void mapping(IO &IO, llvm::DXContainerYAML::Signature &El);
 };
 
+template <> struct MappingTraits<DXContainerYAML::RootSignature> {
+  static void mapping(IO &IO,
+                      llvm::DXContainerYAML::RootSignature &RootSignature);
+};
+
 } // namespace yaml
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 0db5b80f336cb5..8875ddd34fe56c 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -62,6 +62,7 @@ add_llvm_component_library(LLVMAnalysis
   DominanceFrontier.cpp
   DXILResource.cpp
   DXILMetadataAnalysis.cpp
+  DXILRootSignature.cpp
   FunctionPropertiesAnalysis.cpp
   GlobalsModRef.cpp
   GuardUtils.cpp
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index a7f666a3f8b48f..3bd60bfe203f49 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -10,12 +10,15 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <memory>
 
 #define DEBUG_TYPE "dxil-metadata-analysis"
 
@@ -28,6 +31,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
   MMDAI.DXILVersion = TT.getDXILVersion();
   MMDAI.ShaderModelVersion = TT.getOSVersion();
   MMDAI.ShaderProfile = TT.getEnvironment();
+
   NamedMDNode *ValidatorVerNode = M.getNamedMetadata("dx.valver");
   if (ValidatorVerNode) {
     auto *ValVerMD = cast<MDNode>(ValidatorVerNode->getOperand(0));
@@ -37,6 +41,19 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
         VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
   }
 
+  NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
+  if (RootSignatureNode) {
+    auto RootSignatureParser =
+        root_signature::MetadataParser(RootSignatureNode);
+
+    root_signature::VersionedRootSignatureDesc Desc;
+
+    RootSignatureParser.Parse(root_signature::RootSignatureVersion::Version_1,
+                              &Desc);
+
+    MMDAI.RootSignatureDesc = Desc;
+  }
+
   // For all HLSL Shader functions
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
diff --git a/llvm/lib/Analysis/DXILRootSignature.cpp b/llvm/lib/Analysis/DXILRootSignature.cpp
new file mode 100644
index 00000000000000..fce97eb27cf8f8
--- /dev/null
+++ b/llvm/lib/Analysis/DXILRootSignature.cpp
@@ -0,0 +1,110 @@
+//===- DXILRootSignature.cpp - DXIL Root Signature helper objects
+//-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains the parsing logic to extract root signature data
+///       from LLVM IR metadata.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DXILRootSignature.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+
+namespace llvm {
+namespace dxil {
+
+bool root_signature::MetadataParser::Parse(RootSignatureVersion Version,
+                                           VersionedRootSignatureDesc *Desc) {
+  Desc->Version = Version;
+  bool HasError = false;
+
+  for (unsigned int Sid = 0; Sid < Root->getNumOperands(); Sid++) {
+    // This should be an if, for error handling
+    MDNode *Node = cast<MDNode>(Root->getOperand(Sid));
+
+    // Not sure what use this for...
+    Metadata *Func = Node->getOperand(0).get();
+
+    // This should be an if, for error handling
+    MDNode *Elements = cast<MDNode>(Node->getOperand(1).get());
+
+    for (unsigned int Eid = 0; Eid < Elements->getNumOperands(); Eid++) {
+      MDNode *Element = cast<MDNode>(Elements->getOperand(Eid));
+
+      HasError = HasError || ParseRootSignatureElement(Element, Desc);
+    }
+  }
+  return HasError;
+}
+
+bool root_signature::MetadataParser::ParseRootFlags(
+    MDNode *RootFlagNode, VersionedRootSignatureDesc *Desc) {
+
+  assert(RootFlagNode->getNumOperands() == 2 &&
+         "Invalid format for RootFlag Element");
+  auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
+  auto Value = (RootSignatureFlags)Flag->getZExtValue();
+
+  if ((Value & ~RootSignatureFlags::ValidFlags) != RootSignatureFlags::None)
+    return true;
+
+  switch (Desc->Version) {
+
+  case RootSignatureVersion::Version_1:
+    Desc->Desc_1_0.Flags = (RootSignatureFlags)Value;
+    break;
+  case RootSignatureVersion::Version_1_1:
+  case RootSignatureVersion::Version_1_2:
+    llvm_unreachable("Not implemented yet");
+    break;
+  }
+  return false;
+}
+
+bool root_signature::MetadataParser::ParseRootSignatureElement(
+    MDNode *Element, VersionedRootSignatureDesc *Desc) {
+  MDString *ElementText = cast<MDString>(Element->getOperand(0));
+
+  assert(ElementText != nullptr && "First preoperty of element is not ");
+
+  RootSignatureElementKind ElementKind =
+      StringSwitch<RootSignatureElementKind>(ElementText->getString())
+          .Case("RootFlags", RootSignatureElementKind::RootFlags)
+          .Case("RootConstants", RootSignatureElementKind::RootConstants)
+          .Case("RootCBV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootSRV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootUAV", RootSignatureElementKind::RootDescriptor)
+          .Case("Sampler", RootSignatureElementKind::RootDescriptor)
+          .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+          .Case("StaticSampler", RootSignatureElementKind::StaticSampler)
+          .Default(RootSignatureElementKind::None);
+
+  switch (ElementKind) {
+
+  case RootSignatureElementKind::RootFlags: {
+    return ParseRootFlags(Element, Desc);
+    break;
+  }
+
+  case RootSignatureElementKind::RootConstants:
+  case RootSignatureElementKind::RootDescriptor:
+  case RootSignatureElementKind::DescriptorTable:
+  case RootSignatureElementKind::StaticSampler:
+  case RootSignatureElementKind::None:
+    llvm_unreachable("Not Implemented yet");
+    break;
+  }
+
+  return true;
+}
+} // namespace dxil
+} // namespace llvm
diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index 3b1a6203a1f8fc..f50f68df88ec2a 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Object/DXContainer.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Object/Error.h"
 #include "llvm/Support/Alignment.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace llvm;
@@ -92,6 +94,14 @@ Error DXContainer::parseHash(StringRef Part) {
   return Error::success();
 }
 
+Error DXContainer::parseRootSignature(StringRef Part) {
+  dxil::root_signature::VersionedRootSignatureDesc Desc;
+  if (Error Err = readStruct(Part, Part.begin(), Desc))
+    return Err;
+  RootSignature = Desc;
+  return Error::success();
+}
+
 Error DXContainer::parsePSVInfo(StringRef Part) {
   if (PSVInfo)
     return parseFailed("More than one PSV0 part is present in the file");
@@ -192,6 +202,11 @@ Error DXContainer::parsePartOffsets() {
         return Err;
       break;
     case dxbc::PartType::Unknown:
+      break;
+    case dxbc::PartType::RTS0:
+      if (Error Err = parseRootSignature(PartData))
+        return Err;
+
       break;
     }
   }
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 175f1a12f93145..905d409562ff45 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -11,6 +11,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/MC/DXContainerPSVInfo.h"
 #include "llvm/ObjectYAML/ObjectYAML.h"
@@ -261,6 +262,12 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
     }
     case dxbc::PartType::Unknown:
       break; // Skip any handling for unrecognized parts.
+    case dxbc::PartType::RTS0:
+      if (!P.RootSignature.has_value())
+        continue;
+      OS.write(reinterpret_cast<const char *>(&P.RootSignature),
+               sizeof(dxil::root_signature::VersionedRootSignatureDesc));
+      break;
     }
     uint64_t BytesWritten = OS.tell() - DataStart;
     RollingOffset += BytesWritten;
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 5dee1221b27c01..eab3fcc5936f85 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/ObjectYAML/DXContainerYAML.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/ScopedPrinter.h"
 
@@ -188,6 +189,12 @@ void MappingTraits<DXContainerYAML::Signature>::mapping(
   IO.mapRequired("Parameters", S.Parameters);
 }
 
+void MappingTraits<DXContainerYAML::RootSignature>::mapping(
+    IO &IO, DXContainerYAML::RootSignature &S) {
+  IO.mapRequired("Version", S.Version);
+  IO.mapRequired("Flags", S.Flags);
+}
+
 void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
                                                    DXContainerYAML::Part &P) {
   IO.mapRequired("Name", P.Name);
@@ -197,6 +204,7 @@ void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
   IO.mapOptional("Hash", P.Hash);
   IO.mapOptional("PSVInfo", P.Info);
   IO.mapOptional("Signature", P.Signature);
+  IO.mapOptional("RootSignature", P.RootSignature);
 }
 
 void MappingTraits<DXContainerYAML::Object>::mapping(
@@ -290,6 +298,66 @@ void ScalarEnumerationTraits<dxbc::SigComponentType>::enumeration(
     IO.enumCase(Value, E.Name.str().c_str(), E.Value);
 }
 
+template <>
+struct llvm::yaml::ScalarEnumerationTraits<
+    dxil::root_signature::RootSignatureVersion> {
+  static void enumeration(IO &io,
+                          dxil::root_signature::RootSignatureVersion &Val) {
+    io.enumCase(Val, "1.0",
+                dxil::root_signature::RootSignatureVersion::Version_1);
+    io.enumCase(Val, "1.0",
+                dxil::root_signature::RootSignatureVersion::Version_1_0);
+    io.enumCase(Val, "1.1",
+                dxil::root_signature::RootSignatureVersion::Version_1_1);
+    io.enumCase(Val, "1.2",
+                dxil::root_signature::RootSignatureVersion::Version_1_2);
+  }
+};
+
+template <>
+struct llvm::yaml::ScalarEnumerationTraits<
+    dxil::root_signature::RootSignatureFlags> {
+  static void enumeration(IO &io,
+                          dxil::root_signature::RootSignatureFlags &Val) {
+    io.enumCase(Val, "AllowInputAssemblerInputLayout",
+                dxil::root_signature::RootSignatureFlags::
+                    AllowInputAssemblerInputLayout);
+    io.enumCase(
+        Val, "DenyVertexShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyVertexShaderRootAccess);
+    io.enumCase(
+        Val, "DenyHullShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyHullShaderRootAccess);
+    io.enumCase(
+        Val, "DenyDomainShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyDomainShaderRootAccess);
+    io.enumCase(
+        Val, "DenyGeometryShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyGeometryShaderRootAccess);
+    io.enumCase(
+        Val, "DenyPixelShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyPixelShaderRootAccess);
+    io.enumCase(Val, "AllowStreamOutput",
+                dxil::root_signature::RootSignatureFlags::AllowStreamOutput);
+    io.enumCase(Val, "LocalRootSignature",
+                dxil::root_signature::RootSignatureFlags::LocalRootSignature);
+    io.enumCase(Val, "DenyAmplificationShaderRootAccess",
+                dxil::root_signature::RootSignatureFlags::
+                    DenyAmplificationShaderRootAccess);
+    io.enumCase(
+        Val, "DenyMeshShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyMeshShaderRootAccess);
+    io.enumCase(
+        Val, "CBVSRVUAVHeapDirectlyIndexed",
+        dxil::root_signature::RootSignatureFlags::CBVSRVUAVHeapDirectlyIndexed);
+    io.enumCase(
+        Val, "SamplerHeapDirectlyIndexed",
+        dxil::root_signature::RootSignatureFlags::SamplerHeapDirectlyIndexed);
+    io.enumCase(Val, "AllowLowTierReservedHwCbLimit",
+                dxil::root_signature::RootSignatureFlags::
+                    AllowLowTierReservedHwCbLimit);
+  }
+};
 } // namespace yaml
 
 void DXContainerYAML::PSVInfo::mapInfoForVersion(yaml::IO &IO) {
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 7a0bd6a7c88692..e3174d600e6534 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/Direct...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-llvm-analysis

Author: None (joaosaffran)

Changes

This PR adds:

  • Root signature 1.0 definition for RootSignatureFlags
  • Root Signature Generation to DX Container
  • Root Signature RootSignatureFlags extraction from LLVM
  • Root Signature generation to DXIL IR
  • RootSignatureFlags Validation
  • RootSignatureFlags extraction from DXContainer using obj2yaml

Patch is 25.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122396.diff

14 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILMetadataAnalysis.h (+3)
  • (added) llvm/include/llvm/Analysis/DXILRootSignature.h (+88)
  • (modified) llvm/include/llvm/BinaryFormat/DXContainerConstants.def (+1)
  • (modified) llvm/include/llvm/Object/DXContainer.h (+8)
  • (modified) llvm/include/llvm/ObjectYAML/DXContainerYAML.h (+14)
  • (modified) llvm/lib/Analysis/CMakeLists.txt (+1)
  • (modified) llvm/lib/Analysis/DXILMetadataAnalysis.cpp (+17)
  • (added) llvm/lib/Analysis/DXILRootSignature.cpp (+110)
  • (modified) llvm/lib/Object/DXContainer.cpp (+15)
  • (modified) llvm/lib/ObjectYAML/DXContainerEmitter.cpp (+7)
  • (modified) llvm/lib/ObjectYAML/DXContainerYAML.cpp (+68)
  • (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+23)
  • (added) llvm/test/CodeGen/DirectX/ContainerData/RootSignatures/FlagsElement.ll (+27)
  • (modified) llvm/tools/obj2yaml/dxcontainer2yaml.cpp (+22)
diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c61..89c5bffcdbb954 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -10,10 +10,12 @@
 #define LLVM_ANALYSIS_DXILMETADATA_H
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/VersionTuple.h"
 #include "llvm/TargetParser/Triple.h"
+#include <optional>
 
 namespace llvm {
 
@@ -37,6 +39,7 @@ struct ModuleMetadataInfo {
   Triple::EnvironmentType ShaderProfile{Triple::UnknownEnvironment};
   VersionTuple ValidatorVersion{};
   SmallVector<EntryProperties> EntryPropertyVec{};
+  std::optional<root_signature::VersionedRootSignatureDesc> RootSignatureDesc;
   void print(raw_ostream &OS) const;
 };
 
diff --git a/llvm/include/llvm/Analysis/DXILRootSignature.h b/llvm/include/llvm/Analysis/DXILRootSignature.h
new file mode 100644
index 00000000000000..cb3d6192f4404d
--- /dev/null
+++ b/llvm/include/llvm/Analysis/DXILRootSignature.h
@@ -0,0 +1,88 @@
+//===- DXILRootSignature.h - DXIL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with DXIL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_DIRECTX_HLSLROOTSIGNATURE_H
+#define LLVM_DIRECTX_HLSLROOTSIGNATURE_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
+namespace llvm {
+namespace dxil {
+namespace root_signature {
+
+enum class RootSignatureElementKind {
+  None = 0,
+  RootFlags = 1,
+  RootConstants = 2,
+  RootDescriptor = 3,
+  DescriptorTable = 4,
+  StaticSampler = 5
+};
+
+enum class RootSignatureVersion {
+  Version_1 = 1,
+  Version_1_0 = 1,
+  Version_1_1 = 2,
+  Version_1_2 = 3
+};
+
+enum RootSignatureFlags : uint32_t {
+  None = 0,
+  AllowInputAssemblerInputLayout = 0x1,
+  DenyVertexShaderRootAccess = 0x2,
+  DenyHullShaderRootAccess = 0x4,
+  DenyDomainShaderRootAccess = 0x8,
+  DenyGeometryShaderRootAccess = 0x10,
+  DenyPixelShaderRootAccess = 0x20,
+  AllowStreamOutput = 0x40,
+  LocalRootSignature = 0x80,
+  DenyAmplificationShaderRootAccess = 0x100,
+  DenyMeshShaderRootAccess = 0x200,
+  CBVSRVUAVHeapDirectlyIndexed = 0x400,
+  SamplerHeapDirectlyIndexed = 0x800,
+  AllowLowTierReservedHwCbLimit = 0x80000000,
+  ValidFlags = 0x80000fff
+};
+
+struct DxilRootSignatureDesc1_0 {
+  RootSignatureFlags Flags;
+};
+
+struct VersionedRootSignatureDesc {
+  RootSignatureVersion Version;
+  union {
+    DxilRootSignatureDesc1_0 Desc_1_0;
+  };
+
+  bool isPopulated();
+
+  void swapBytes();
+};
+
+class MetadataParser {
+public:
+  NamedMDNode *Root;
+  MetadataParser(NamedMDNode *Root) : Root(Root) {}
+
+  bool Parse(RootSignatureVersion Version, VersionedRootSignatureDesc *Desc);
+
+private:
+  bool ParseRootFlags(MDNode *RootFlagRoot, VersionedRootSignatureDesc *Desc);
+  bool ParseRootSignatureElement(MDNode *Element,
+                                 VersionedRootSignatureDesc *Desc);
+};
+} // namespace root_signature
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_DIRECTX_HLSLROOTSIGNATURE_H
diff --git a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
index 1aacbb2f65b27f..38b69228cd3975 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
+++ b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
@@ -4,6 +4,7 @@ CONTAINER_PART(DXIL)
 CONTAINER_PART(SFI0)
 CONTAINER_PART(HASH)
 CONTAINER_PART(PSV0)
+CONTAINER_PART(RTS0)
 CONTAINER_PART(ISG1)
 CONTAINER_PART(OSG1)
 CONTAINER_PART(PSG1)
diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h
index 19c83ba6c6e85d..9a6aa8224eddf4 100644
--- a/llvm/include/llvm/Object/DXContainer.h
+++ b/llvm/include/llvm/Object/DXContainer.h
@@ -17,6 +17,7 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBufferRef.h"
@@ -287,6 +288,7 @@ class DXContainer {
   std::optional<uint64_t> ShaderFeatureFlags;
   std::optional<dxbc::ShaderHash> Hash;
   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
+  std::optional<dxil::root_signature::VersionedRootSignatureDesc> RootSignature;
   DirectX::Signature InputSignature;
   DirectX::Signature OutputSignature;
   DirectX::Signature PatchConstantSignature;
@@ -296,6 +298,7 @@ class DXContainer {
   Error parseDXILHeader(StringRef Part);
   Error parseShaderFeatureFlags(StringRef Part);
   Error parseHash(StringRef Part);
+  Error parseRootSignature(StringRef Part);
   Error parsePSVInfo(StringRef Part);
   Error parseSignature(StringRef Part, DirectX::Signature &Array);
   friend class PartIterator;
@@ -382,6 +385,11 @@ class DXContainer {
 
   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
 
+  std::optional<dxil::root_signature::VersionedRootSignatureDesc>
+  getRootSignature() const {
+    return RootSignature;
+  }
+
   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
     return PSVInfo;
   };
diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index 66ad057ab0e30f..e9da51f61c0a2b 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -16,6 +16,7 @@
 #define LLVM_OBJECTYAML_DXCONTAINERYAML_H
 
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/ObjectYAML/YAML.h"
 #include "llvm/Support/YAMLTraits.h"
@@ -149,6 +150,13 @@ struct Signature {
   llvm::SmallVector<SignatureParameter> Parameters;
 };
 
+struct RootSignature {
+  RootSignature() = default;
+
+  dxil::root_signature::RootSignatureVersion Version;
+  dxil::root_signature::RootSignatureFlags Flags;
+};
+
 struct Part {
   Part() = default;
   Part(std::string N, uint32_t S) : Name(N), Size(S) {}
@@ -159,6 +167,7 @@ struct Part {
   std::optional<ShaderHash> Hash;
   std::optional<PSVInfo> Info;
   std::optional<DXContainerYAML::Signature> Signature;
+  std::optional<DXContainerYAML::RootSignature> RootSignature;
 };
 
 struct Object {
@@ -241,6 +250,11 @@ template <> struct MappingTraits<DXContainerYAML::Signature> {
   static void mapping(IO &IO, llvm::DXContainerYAML::Signature &El);
 };
 
+template <> struct MappingTraits<DXContainerYAML::RootSignature> {
+  static void mapping(IO &IO,
+                      llvm::DXContainerYAML::RootSignature &RootSignature);
+};
+
 } // namespace yaml
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 0db5b80f336cb5..8875ddd34fe56c 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -62,6 +62,7 @@ add_llvm_component_library(LLVMAnalysis
   DominanceFrontier.cpp
   DXILResource.cpp
   DXILMetadataAnalysis.cpp
+  DXILRootSignature.cpp
   FunctionPropertiesAnalysis.cpp
   GlobalsModRef.cpp
   GuardUtils.cpp
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index a7f666a3f8b48f..3bd60bfe203f49 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -10,12 +10,15 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <memory>
 
 #define DEBUG_TYPE "dxil-metadata-analysis"
 
@@ -28,6 +31,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
   MMDAI.DXILVersion = TT.getDXILVersion();
   MMDAI.ShaderModelVersion = TT.getOSVersion();
   MMDAI.ShaderProfile = TT.getEnvironment();
+
   NamedMDNode *ValidatorVerNode = M.getNamedMetadata("dx.valver");
   if (ValidatorVerNode) {
     auto *ValVerMD = cast<MDNode>(ValidatorVerNode->getOperand(0));
@@ -37,6 +41,19 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
         VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
   }
 
+  NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
+  if (RootSignatureNode) {
+    auto RootSignatureParser =
+        root_signature::MetadataParser(RootSignatureNode);
+
+    root_signature::VersionedRootSignatureDesc Desc;
+
+    RootSignatureParser.Parse(root_signature::RootSignatureVersion::Version_1,
+                              &Desc);
+
+    MMDAI.RootSignatureDesc = Desc;
+  }
+
   // For all HLSL Shader functions
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
diff --git a/llvm/lib/Analysis/DXILRootSignature.cpp b/llvm/lib/Analysis/DXILRootSignature.cpp
new file mode 100644
index 00000000000000..fce97eb27cf8f8
--- /dev/null
+++ b/llvm/lib/Analysis/DXILRootSignature.cpp
@@ -0,0 +1,110 @@
+//===- DXILRootSignature.cpp - DXIL Root Signature helper objects
+//-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains the parsing logic to extract root signature data
+///       from LLVM IR metadata.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DXILRootSignature.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+
+namespace llvm {
+namespace dxil {
+
+bool root_signature::MetadataParser::Parse(RootSignatureVersion Version,
+                                           VersionedRootSignatureDesc *Desc) {
+  Desc->Version = Version;
+  bool HasError = false;
+
+  for (unsigned int Sid = 0; Sid < Root->getNumOperands(); Sid++) {
+    // This should be an if, for error handling
+    MDNode *Node = cast<MDNode>(Root->getOperand(Sid));
+
+    // Not sure what use this for...
+    Metadata *Func = Node->getOperand(0).get();
+
+    // This should be an if, for error handling
+    MDNode *Elements = cast<MDNode>(Node->getOperand(1).get());
+
+    for (unsigned int Eid = 0; Eid < Elements->getNumOperands(); Eid++) {
+      MDNode *Element = cast<MDNode>(Elements->getOperand(Eid));
+
+      HasError = HasError || ParseRootSignatureElement(Element, Desc);
+    }
+  }
+  return HasError;
+}
+
+bool root_signature::MetadataParser::ParseRootFlags(
+    MDNode *RootFlagNode, VersionedRootSignatureDesc *Desc) {
+
+  assert(RootFlagNode->getNumOperands() == 2 &&
+         "Invalid format for RootFlag Element");
+  auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
+  auto Value = (RootSignatureFlags)Flag->getZExtValue();
+
+  if ((Value & ~RootSignatureFlags::ValidFlags) != RootSignatureFlags::None)
+    return true;
+
+  switch (Desc->Version) {
+
+  case RootSignatureVersion::Version_1:
+    Desc->Desc_1_0.Flags = (RootSignatureFlags)Value;
+    break;
+  case RootSignatureVersion::Version_1_1:
+  case RootSignatureVersion::Version_1_2:
+    llvm_unreachable("Not implemented yet");
+    break;
+  }
+  return false;
+}
+
+bool root_signature::MetadataParser::ParseRootSignatureElement(
+    MDNode *Element, VersionedRootSignatureDesc *Desc) {
+  MDString *ElementText = cast<MDString>(Element->getOperand(0));
+
+  assert(ElementText != nullptr && "First preoperty of element is not ");
+
+  RootSignatureElementKind ElementKind =
+      StringSwitch<RootSignatureElementKind>(ElementText->getString())
+          .Case("RootFlags", RootSignatureElementKind::RootFlags)
+          .Case("RootConstants", RootSignatureElementKind::RootConstants)
+          .Case("RootCBV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootSRV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootUAV", RootSignatureElementKind::RootDescriptor)
+          .Case("Sampler", RootSignatureElementKind::RootDescriptor)
+          .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+          .Case("StaticSampler", RootSignatureElementKind::StaticSampler)
+          .Default(RootSignatureElementKind::None);
+
+  switch (ElementKind) {
+
+  case RootSignatureElementKind::RootFlags: {
+    return ParseRootFlags(Element, Desc);
+    break;
+  }
+
+  case RootSignatureElementKind::RootConstants:
+  case RootSignatureElementKind::RootDescriptor:
+  case RootSignatureElementKind::DescriptorTable:
+  case RootSignatureElementKind::StaticSampler:
+  case RootSignatureElementKind::None:
+    llvm_unreachable("Not Implemented yet");
+    break;
+  }
+
+  return true;
+}
+} // namespace dxil
+} // namespace llvm
diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index 3b1a6203a1f8fc..f50f68df88ec2a 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Object/DXContainer.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Object/Error.h"
 #include "llvm/Support/Alignment.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace llvm;
@@ -92,6 +94,14 @@ Error DXContainer::parseHash(StringRef Part) {
   return Error::success();
 }
 
+Error DXContainer::parseRootSignature(StringRef Part) {
+  dxil::root_signature::VersionedRootSignatureDesc Desc;
+  if (Error Err = readStruct(Part, Part.begin(), Desc))
+    return Err;
+  RootSignature = Desc;
+  return Error::success();
+}
+
 Error DXContainer::parsePSVInfo(StringRef Part) {
   if (PSVInfo)
     return parseFailed("More than one PSV0 part is present in the file");
@@ -192,6 +202,11 @@ Error DXContainer::parsePartOffsets() {
         return Err;
       break;
     case dxbc::PartType::Unknown:
+      break;
+    case dxbc::PartType::RTS0:
+      if (Error Err = parseRootSignature(PartData))
+        return Err;
+
       break;
     }
   }
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 175f1a12f93145..905d409562ff45 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -11,6 +11,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/MC/DXContainerPSVInfo.h"
 #include "llvm/ObjectYAML/ObjectYAML.h"
@@ -261,6 +262,12 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
     }
     case dxbc::PartType::Unknown:
       break; // Skip any handling for unrecognized parts.
+    case dxbc::PartType::RTS0:
+      if (!P.RootSignature.has_value())
+        continue;
+      OS.write(reinterpret_cast<const char *>(&P.RootSignature),
+               sizeof(dxil::root_signature::VersionedRootSignatureDesc));
+      break;
     }
     uint64_t BytesWritten = OS.tell() - DataStart;
     RollingOffset += BytesWritten;
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 5dee1221b27c01..eab3fcc5936f85 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/ObjectYAML/DXContainerYAML.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/ScopedPrinter.h"
 
@@ -188,6 +189,12 @@ void MappingTraits<DXContainerYAML::Signature>::mapping(
   IO.mapRequired("Parameters", S.Parameters);
 }
 
+void MappingTraits<DXContainerYAML::RootSignature>::mapping(
+    IO &IO, DXContainerYAML::RootSignature &S) {
+  IO.mapRequired("Version", S.Version);
+  IO.mapRequired("Flags", S.Flags);
+}
+
 void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
                                                    DXContainerYAML::Part &P) {
   IO.mapRequired("Name", P.Name);
@@ -197,6 +204,7 @@ void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
   IO.mapOptional("Hash", P.Hash);
   IO.mapOptional("PSVInfo", P.Info);
   IO.mapOptional("Signature", P.Signature);
+  IO.mapOptional("RootSignature", P.RootSignature);
 }
 
 void MappingTraits<DXContainerYAML::Object>::mapping(
@@ -290,6 +298,66 @@ void ScalarEnumerationTraits<dxbc::SigComponentType>::enumeration(
     IO.enumCase(Value, E.Name.str().c_str(), E.Value);
 }
 
+template <>
+struct llvm::yaml::ScalarEnumerationTraits<
+    dxil::root_signature::RootSignatureVersion> {
+  static void enumeration(IO &io,
+                          dxil::root_signature::RootSignatureVersion &Val) {
+    io.enumCase(Val, "1.0",
+                dxil::root_signature::RootSignatureVersion::Version_1);
+    io.enumCase(Val, "1.0",
+                dxil::root_signature::RootSignatureVersion::Version_1_0);
+    io.enumCase(Val, "1.1",
+                dxil::root_signature::RootSignatureVersion::Version_1_1);
+    io.enumCase(Val, "1.2",
+                dxil::root_signature::RootSignatureVersion::Version_1_2);
+  }
+};
+
+template <>
+struct llvm::yaml::ScalarEnumerationTraits<
+    dxil::root_signature::RootSignatureFlags> {
+  static void enumeration(IO &io,
+                          dxil::root_signature::RootSignatureFlags &Val) {
+    io.enumCase(Val, "AllowInputAssemblerInputLayout",
+                dxil::root_signature::RootSignatureFlags::
+                    AllowInputAssemblerInputLayout);
+    io.enumCase(
+        Val, "DenyVertexShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyVertexShaderRootAccess);
+    io.enumCase(
+        Val, "DenyHullShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyHullShaderRootAccess);
+    io.enumCase(
+        Val, "DenyDomainShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyDomainShaderRootAccess);
+    io.enumCase(
+        Val, "DenyGeometryShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyGeometryShaderRootAccess);
+    io.enumCase(
+        Val, "DenyPixelShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyPixelShaderRootAccess);
+    io.enumCase(Val, "AllowStreamOutput",
+                dxil::root_signature::RootSignatureFlags::AllowStreamOutput);
+    io.enumCase(Val, "LocalRootSignature",
+                dxil::root_signature::RootSignatureFlags::LocalRootSignature);
+    io.enumCase(Val, "DenyAmplificationShaderRootAccess",
+                dxil::root_signature::RootSignatureFlags::
+                    DenyAmplificationShaderRootAccess);
+    io.enumCase(
+        Val, "DenyMeshShaderRootAccess",
+        dxil::root_signature::RootSignatureFlags::DenyMeshShaderRootAccess);
+    io.enumCase(
+        Val, "CBVSRVUAVHeapDirectlyIndexed",
+        dxil::root_signature::RootSignatureFlags::CBVSRVUAVHeapDirectlyIndexed);
+    io.enumCase(
+        Val, "SamplerHeapDirectlyIndexed",
+        dxil::root_signature::RootSignatureFlags::SamplerHeapDirectlyIndexed);
+    io.enumCase(Val, "AllowLowTierReservedHwCbLimit",
+                dxil::root_signature::RootSignatureFlags::
+                    AllowLowTierReservedHwCbLimit);
+  }
+};
 } // namespace yaml
 
 void DXContainerYAML::PSVInfo::mapInfoForVersion(yaml::IO &IO) {
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 7a0bd6a7c88692..e3174d600e6534 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/Direct...
[truncated]

Copy link

github-actions bot commented Jan 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the past when we've made changes to support emitting new parts of the container file we've broken the changes up into multiple PRs that land separately. Usually the first PR adds parsing and YAML tooling (with YAML roundtrip tests), and the second PR adds support for generating the data from IR and emitting the final file.

I believe this PR would benefit from being split up that way to make it easier to review.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a lot about this PR that needs to change. I'll work through writing detailed feedback next week but some high level things:

  1. Splitting up the change will make it a lot easier to review (I commented this separately too).
  2. There are no tests of the yaml2obj support.
  3. It doesn't look like you're handling endianness on the encoding.
  4. The code structure is all off.

ObjectYAML cannot and should not depend on the IR analysis library. The code for encoding binary structures should be part of the MC library not the analysis library. Take a look at how the code for the PSV0 data is structured because this should be closer to that.

@joaosaffran joaosaffran marked this pull request as draft January 11, 2025 00:14
@joaosaffran joaosaffran force-pushed the roosignatures/backend branch from 988e7b0 to 8adb678 Compare January 13, 2025 22:10
@joaosaffran joaosaffran changed the title [DXIL] Adding support to RootSignatureFlags generation to DXContainer [DXIL] Adding support to RootSignatureFlags in obj2yaml Jan 13, 2025
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely moving in the right direction.

void MappingTraits<DXContainerYAML::RootSignatureDesc>::mapping(
IO &IO, DXContainerYAML::RootSignatureDesc &S) {
IO.mapRequired("Version", S.Version);
IO.mapRequired("Flags", S.Flags);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might look at how we did the ShaderFeatureFlags. That approach makes the flags print nicer and more human readable in the YAML output.

In going that route the one thing I would probably do differently is making each flag "optional" rather than required in the yaml with a default value as false. That will make the printing more concise.

FileSize: 1672
PartCount: 7
PartOffsets: [ 60, 1496, 1512, 1540, 1556, 1572, 1588 ]
Parts:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is valid to have a root signature in a container by itself, so you shouldn't need a bunch of different parts. This test is an example of generating a root-signature-only container:

https://github.com/microsoft/DirectXShaderCompiler/blob/main/tools/clang/test/DXC/local_rs.hlsl

@joaosaffran joaosaffran marked this pull request as ready for review January 14, 2025 23:43
Copy link
Contributor

@inbelic inbelic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a collection of nits

if (RS && RS.has_value())
NewPart.RootSignature = DXContainerYAML::RootSignatureDesc(*RS);
break;
break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
break;

@@ -153,6 +153,12 @@ dumpDXContainer(MemoryBufferRef Source) {
break;
case dxbc::PartType::Unknown:
break;
case dxbc::PartType::RTS0:
std::optional<dxbc::RootSignatureDesc> RS = Container.getRootSignature();
if (RS && RS.has_value())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (RS && RS.has_value())
if (RS.has_value())

@@ -152,6 +161,11 @@ enum class FeatureFlags : uint64_t {
static_assert((uint64_t)FeatureFlags::NextUnusedBit <= 1ull << 63,
"Shader flag bits exceed enum size.");

#define ROOT_ELEMENT_FLAG(Num, Val) Val = 1ull << Num,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo, I think it would be best to just define the values as hexadecimal literals rather than doing bit shifts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see now that you are following the convention set elsewhere in the file.

case dxbc::PartType::RTS0:
if (Error Err = parseRootSignature(PartData))
return Err;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

@@ -52,6 +53,25 @@ SHADER_FEATURE_FLAG(31, 36, NextUnusedBit, "Next reserved shader flag bit (not a
#undef SHADER_FEATURE_FLAG
#endif // SHADER_FEATURE_FLAG

#ifdef ROOT_ELEMENT_FLAG

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could add a similar description as we have in the other enums

@@ -0,0 +1,29 @@
//===- llvm/MC/DXContainerRootSignature.cpp - DXContainer RootSignature -*- C++
//-------*-===//
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like clang format wrapped your header.

sys::swapByteOrder(Flags);
}

OS.write(reinterpret_cast<const char *>(this), SizeInfo);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're writing the size without byte-swapping, and byte swapping the fields (in place) without writing them.

This seems wrong. Wouldn't it be better to just write:

uint32_t SizeInfo = sizeof(this);
support::endian::write(OS, SizeInfo, llvm::endianness::little);
support::endian::write(OS, Version, llvm::endianness::little);
support::endian::write(OS, Flags, llvm::endianness::little);

void MappingTraits<DXContainerYAML::RootSignatureDesc>::mapping(
IO &IO, DXContainerYAML::RootSignatureDesc &S) {
IO.mapRequired("Version", S.Version);
#define ROOT_ELEMENT_FLAG(Num, Val, Str) IO.mapRequired(#Val, S.Val);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could instead consider doing this as a mapOptional with false as the default value. That would make it more concise in the YAML representation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

4 participants