From 3ebd08ec67c37dd1a75ec566a89a61de64f48196 Mon Sep 17 00:00:00 2001 From: Shani Elharrar Date: Wed, 7 Dec 2022 15:28:24 +0200 Subject: [PATCH] Assembly - Run the keep rule on Uber jar... Instead of running them jar by jar, this makes the keep rule much more usable. --- src/main/scala/sbtassembly/Assembly.scala | 70 ++++++++++++++++++----- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/src/main/scala/sbtassembly/Assembly.scala b/src/main/scala/sbtassembly/Assembly.scala index b0f9e18a..033bb1e9 100644 --- a/src/main/scala/sbtassembly/Assembly.scala +++ b/src/main/scala/sbtassembly/Assembly.scala @@ -3,27 +3,30 @@ package sbtassembly import com.eed3si9n.jarjarabrams._ import sbt.Def.Initialize import sbt.Keys._ -import sbt.Package.{ manifestFormat, JarManifest, MainClass, ManifestAttributes } +import sbt.Package.{JarManifest, MainClass, ManifestAttributes, manifestFormat} import sbt.internal.util.HListFormats._ import sbt.internal.util.HNil import sbt.internal.util.Types.:+: -import sbt.io.{ DirectoryFilter => _, IO => _, Path => _, Using } +import sbt.io.{Using, DirectoryFilter => _, IO => _, Path => _} import sbt.util.FileInfo.lastModified -import sbt.util.Tracked.{ inputChanged, lastOutput } -import sbt.util.{ FilesInfo, Level, ModifiedFileInfo } -import sbt.{ File, Logger, _ } +import sbt.util.Tracked.{inputChanged, lastOutput} +import sbt.util.{FilesInfo, Level, ModifiedFileInfo} +import sbt.{File, Logger, _} import sbt.Tags.Tag import CacheImplicits._ -import sbtassembly.AssemblyPlugin.autoImport.{ Assembly => _, _ } +import com.eed3si9n.jarjar.util.{EntryStruct, IoUtil} +import com.eed3si9n.jarjar.{JJProcessor, Keep} +import org.apache.logging.log4j.core.util.IOUtils +import sbtassembly.AssemblyPlugin.autoImport.{Assembly => _, _} import sbtassembly.PluginCompat.ClasspathUtilities import java.io._ import java.net.URI -import java.nio.file.attribute.{ BasicFileAttributeView, FileTime, PosixFilePermission } -import java.nio.file.{ Path, _ } +import java.nio.file.attribute.{BasicFileAttributeView, FileTime, PosixFilePermission} +import java.nio.file.{Path, _} import java.security.MessageDigest import java.time.Instant -import java.util.jar.{ Attributes => JAttributes, JarFile, Manifest => JManifest } +import java.util.jar.{JarFile, Attributes => JAttributes, Manifest => JManifest} import scala.annotation.tailrec import scala.collection.GenSeq import scala.collection.JavaConverters._ @@ -260,7 +263,12 @@ object Assembly { } } - val classShader = shader(ao.shadeRules.filter(_.isApplicableToCompiling), log) + val classShader = shader(ao.shadeRules.filter(_.isApplicableToCompiling).filter(x => x.shadePattern match { + case ShadePattern.Rename(_) => true + case ShadePattern.Zap(_) => true + case ShadePattern.Keep(_) => false + }), log) + val classByParentDir = if (!ao.includeBin) Vector.empty[(File, File)] else dirs.flatMap(dir => (dir.data ** (-DirectoryFilter)).get.map(dir.data -> _)) @@ -336,12 +344,22 @@ object Assembly { timed(Level.Debug, "Finding remaining conflicts that were not merged") { reportConflictsMissedByTheMerge(mergedEntries, log) } - val jarEntriesToWrite = timed(Level.Debug, "Sort/Parallelize merged entries") { - if (ao.repeatableBuild) // we need the jars in a specific order to have a consistent hash - mergedEntries.flatMap(_.entries).seq.sortBy(_.target) - else // we actually gain performance when creating the jar in parallel, but we won't have a consistent hash - mergedEntries.flatMap(_.entries).par + + val jarEntriesToWrite = { + val temp = mergedEntries.flatMap(_.entries) + + val withKeepRule = timed(Level.Debug, "Keep") { + keepShader(ao.shadeRules, log, temp) + } + + timed(Level.Debug, "Sort/Parallelize merged entries") { + if (ao.repeatableBuild) // we need the jars in a specific order to have a consistent hash + withKeepRule.sortBy(_.target) + else // we actually gain performance when creating the jar in parallel, but we won't have a consistent hash + withKeepRule.par + } } + val localTime = timestamp .map(t => t - java.util.TimeZone.getDefault.getOffset(t)) .getOrElse(System.currentTimeMillis()) @@ -481,6 +499,28 @@ object Assembly { def isScalaLibraryFile(scalaLibraries: Vector[String], file: File): Boolean = scalaLibraries exists { x => file.getName startsWith x } + private[sbtassembly] def keepShader(shadeRules: SeqShadeRules, log: Logger, entries: Seq[JarEntry]): Seq[JarEntry] = { + val jjRules = shadeRules.map(_.shadePattern).collect({ + case ShadePattern.Keep(patterns) => + patterns.map(pattern => { + val jrule = new Keep() + jrule.setPattern(pattern) + jrule + }) + }).flatten + + val proc = new JJProcessor(jjRules, false, true, null) + entries.foreach({ entry => + val entryStruct = new EntryStruct() + entryStruct.name = entry.target + entryStruct.data = Streamable.bytes(entry.stream()) + proc.process(entryStruct) + }) + + val itemsToExclude = proc.getExcludes + entries.filterNot(entry => itemsToExclude.contains(entry.target)) + } + private[sbtassembly] def shader( shadeRules: SeqShadeRules, log: Logger