diff --git a/lib/diamond/filters.ts b/lib/diamond/filters.ts new file mode 100644 index 00000000..facaa3c0 --- /dev/null +++ b/lib/diamond/filters.ts @@ -0,0 +1,73 @@ +import { FacetCutAction } from './utils'; +import { AddressZero } from '@ethersproject/constants'; + +export interface FacetFilter { + type: string; + action: FacetCutAction; + contract: string; + selectors: string[]; +} + +// returns true if the selector is found in the only or except filters +export function selectorIsFiltered( + only: FacetFilter[], + except: FacetFilter[], + contract: string, + selector: string, +): boolean { + if (only.length > 0) { + // include selectors found in only, exclude all others + return includes(only, contract, selector); + } + + if (except.length > 0) { + // exclude selectors found in except, include all others + return !includes(except, contract, selector); + } + + // if neither only or except are used, then include all selectors + return true; +} + +// returns true if the selector is found in the filters +export function includes( + filters: FacetFilter[], + contract: string, + selector: string, +): boolean { + return filters.some( + (filter) => + [filter.contract, AddressZero].includes(contract) && + filter.selectors.includes(selector), + ); +} + +// validates that different filter types do not contain the same contract +export function validateFilters(only: FacetFilter[], except: FacetFilter[]) { + if (only.length > 0 && except.length > 0) { + only.forEach((o) => { + except.forEach((e) => { + if (o.contract === e.contract) { + throw new Error( + 'only and except filters cannot contain the same contract', + ); + } + }); + }); + } +} + +export function destructureFilters( + filters: FacetFilter[], + action?: FacetCutAction, +) { + const only = filters.filter((f) => f.type === 'only'); + const except = filters.filter((f) => f.type === 'except'); + + if (action !== undefined) { + only.filter((f) => f.action === action); + except.filter((f) => f.action === action); + } + + return { only, except }; +} diff --git a/lib/diamond/utils.ts b/lib/diamond/utils.ts new file mode 100644 index 00000000..284b651e --- /dev/null +++ b/lib/diamond/utils.ts @@ -0,0 +1,296 @@ +import { + FacetFilter, + destructureFilters, + selectorIsFiltered, + validateFilters, +} from './filters'; +import { AddressZero } from '@ethersproject/constants'; +import { Contract, ContractReceipt } from '@ethersproject/contracts'; +import { + IDiamondReadable, + IDiamondWritable, +} from '@solidstate/typechain-types'; + +export enum FacetCutAction { + ADD, + REPLACE, + REMOVE, +} + +export interface Facet { + target: string; + selectors: string[]; +} + +export interface FacetCut extends Facet { + action: FacetCutAction; +} + +// returns a list of function signatures for a contract +export function getFunctionSignatures(contract: Contract): string[] { + return Object.keys(contract.interface.functions); +} + +// returns a list of selectors for a contract +export function getSelectors(contract: Contract): string[] { + return getFunctionSignatures(contract).map((signature) => + contract.interface.getSighash(signature), + ); +} + +// returns a list of Facets for a contract +export function getFacets(contracts: Contract[]): Facet[] { + return contracts.map((contract) => { + return { + target: contract.address, + selectors: getSelectors(contract), + }; + }); +} + +// returns a FacetCut +export function getFacetCut( + target: string, + selectors: string[], + action: number = 0, +): FacetCut { + return { + target: target, + action: action, + selectors: selectors, + }; +} + +// returns true if the selector is found in the facets +export function selectorExistsInFacets( + selector: string, + facets: Facet[], +): boolean { + return facets.some((facet) => facet.selectors.includes(selector)); +} + +// preview FacetCut which adds unregistered selectors +export async function addUnregisteredSelectors( + diamond: IDiamondReadable, + contracts: Contract[], + filters: FacetFilter[] = [], +): Promise { + const { only, except } = destructureFilters(filters, FacetCutAction.ADD); + validateFilters(only, except); + + const diamondFacets: Facet[] = await diamond.facets(); + const facets = getFacets(contracts); + + let selectorsAdded = false; + let facetCuts: FacetCut[] = []; + + // if facet selector is unregistered then it should be added to the diamond. + for (const facet of facets) { + for (const selector of facet.selectors) { + const target = facet.target; + + if ( + target !== diamond.address && + selector.length > 0 && + !selectorExistsInFacets(selector, diamondFacets) && + selectorIsFiltered(only, except, target, selector) + ) { + facetCuts.push( + getFacetCut(facet.target, [selector], FacetCutAction.ADD), + ); + + selectorsAdded = true; + } + } + } + + if (!selectorsAdded) { + throw new Error('No selectors were added to FacetCut'); + } + + return groupFacetCuts(facetCuts); +} + +// preview FacetCut which replaces registered selectors with unregistered selectors +export async function replaceRegisteredSelectors( + diamond: IDiamondReadable, + contracts: Contract[], + filters: FacetFilter[] = [], +): Promise { + const { only, except } = destructureFilters(filters, FacetCutAction.REPLACE); + validateFilters(only, except); + + const diamondFacets: Facet[] = await diamond.facets(); + const facets = getFacets(contracts); + + let selectorsReplaced = false; + let facetCuts: FacetCut[] = []; + + // if a facet selector is registered with a different target address, the target will + // be replaced + for (const facet of facets) { + for (const selector of facet.selectors) { + const target = facet.target; + const oldTarget = await diamond.facetAddress(selector); + + if ( + target != oldTarget && + target != AddressZero && + target != diamond.address && + selector.length > 0 && + selectorExistsInFacets(selector, diamondFacets) && + selectorIsFiltered(only, except, target, selector) + ) { + facetCuts.push(getFacetCut(target, [selector], FacetCutAction.REPLACE)); + + selectorsReplaced = true; + } + } + } + + if (!selectorsReplaced) { + throw new Error('No selectors were replaced in FacetCut'); + } + + return groupFacetCuts(facetCuts); +} + +// preview FacetCut which removes registered selectors +export async function removeRegisteredSelectors( + diamond: IDiamondReadable, + contracts: Contract[], + filters: FacetFilter[] = [], +): Promise { + const { only, except } = destructureFilters(filters, FacetCutAction.REMOVE); + validateFilters(only, except); + + const diamondFacets: Facet[] = await diamond.facets(); + const facets = getFacets(contracts); + + let selectorsRemoved = false; + let facetCuts: FacetCut[] = []; + + // if a registered selector is not found in the facets then it should be removed + // from the diamond + for (const diamondFacet of diamondFacets) { + for (const selector of diamondFacet.selectors) { + const target = diamondFacet.target; + + if ( + target != AddressZero && + target != diamond.address && + selector.length > 0 && + !selectorExistsInFacets(selector, facets) && + selectorIsFiltered(only, except, AddressZero, selector) + ) { + facetCuts.push( + getFacetCut(AddressZero, [selector], FacetCutAction.REMOVE), + ); + + selectorsRemoved = true; + } + } + } + + if (!selectorsRemoved) { + throw new Error('No selectors were removed from FacetCut'); + } + + return groupFacetCuts(facetCuts); +} + +// preview a FacetCut which adds, replaces, or removes selectors, as needed +export async function previewFacetCut( + diamond: IDiamondReadable, + contracts: Contract[], + filters: FacetFilter[] = [], +): Promise { + let addFacetCuts: FacetCut[] = []; + let replaceFacetCuts: FacetCut[] = []; + let removeFacetCuts: FacetCut[] = []; + + try { + addFacetCuts = await addUnregisteredSelectors(diamond, contracts, filters); + } catch (error) { + console.log(`WARNING: ${(error as Error).message}`); + } + + try { + replaceFacetCuts = await replaceRegisteredSelectors( + diamond, + contracts, + filters, + ); + } catch (error) { + console.log(`WARNING: ${(error as Error).message}`); + } + + try { + removeFacetCuts = await removeRegisteredSelectors( + diamond, + contracts, + filters, + ); + } catch (error) { + console.log(`WARNING: ${(error as Error).message}`); + } + + return groupFacetCuts([ + ...addFacetCuts, + ...replaceFacetCuts, + ...removeFacetCuts, + ]); +} + +// executes a DiamondCut using the provided FacetCut +export async function diamondCut( + diamond: IDiamondWritable, + facetCut: FacetCut[], + target: string = AddressZero, + data: string = '0x', +): Promise { + return (await diamond.diamondCut(facetCut, target, data)).wait(); +} + +// groups facet cuts by target address and action type +export function groupFacetCuts(facetCuts: FacetCut[]): FacetCut[] { + const cuts = facetCuts.reduce((acc: FacetCut[], facetCut: FacetCut) => { + if (acc.length == 0) acc.push(facetCut); + + let exists = false; + + acc.forEach((_, i) => { + if ( + acc[i].action == facetCut.action && + acc[i].target == facetCut.target + ) { + acc[i].selectors.push(...facetCut.selectors); + // removes duplicates, if there are any + acc[i].selectors = [...new Set(acc[i].selectors)]; + exists = true; + } + }); + + // push facet cut if it does not already exist + if (!exists) acc.push(facetCut); + + return acc; + }, []); + + let cache: any = {}; + + // checks if selector is used multiple times, emits warning + cuts.forEach((cut) => { + cut.selectors.forEach((selector: string) => { + if (cache[selector]) { + console.log( + `WARNING: selector: ${selector}, target: ${cut.target} is defined in multiple cuts`, + ); + } else { + cache[selector] = true; + } + }); + }); + + return cuts; +} diff --git a/lib/index.ts b/lib/index.ts index e6459ee4..6a4c3b7e 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -2,3 +2,4 @@ export * from './bn_conversion'; export * from './erc20_permit'; export * from './mocha_describe_filter'; export * from './sign_data'; +export * from './diamond/utils'; diff --git a/lib/package.json b/lib/package.json index 3f4ac877..b536e55e 100644 --- a/lib/package.json +++ b/lib/package.json @@ -25,6 +25,8 @@ "tsc-clean": "tsc --build --clean tsconfig.json" }, "dependencies": { + "@ethersproject/constants": "^5.7.0", + "@ethersproject/contracts": "^5.7.0", "eth-permit": "^0.1.10" }, "files": [ diff --git a/spec/proxy/diamond/SolidStateDiamond.behavior.ts b/spec/proxy/diamond/SolidStateDiamond.behavior.ts index 3326604b..aeffbf8c 100644 --- a/spec/proxy/diamond/SolidStateDiamond.behavior.ts +++ b/spec/proxy/diamond/SolidStateDiamond.behavior.ts @@ -24,7 +24,7 @@ import { MockContract, } from '@ethereum-waffle/mock-contract'; import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers'; -import { describeFilter } from '@solidstate/library'; +import { describeFilter, FacetCutAction } from '@solidstate/library'; import { ISolidStateDiamond } from '@solidstate/typechain-types'; import { expect } from 'chai'; import { ethers } from 'hardhat'; @@ -116,13 +116,17 @@ export function describeBehaviorOfSolidStateDiamond( const expectedSelectors = []; for (let selector of selectors) { - await instance - .connect(owner) - .diamondCut( - [{ target: facet.address, action: 0, selectors: [selector] }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: facet.address, + action: FacetCutAction.ADD, + selectors: [selector], + }, + ], + ethers.constants.AddressZero, + '0x', + ); expectedSelectors.push(selector); @@ -152,7 +156,7 @@ export function describeBehaviorOfSolidStateDiamond( await instance .connect(owner) .diamondCut( - [{ target: facet.address, action: 0, selectors }], + [{ target: facet.address, action: FacetCutAction.ADD, selectors }], ethers.constants.AddressZero, '0x', ); @@ -164,7 +168,7 @@ export function describeBehaviorOfSolidStateDiamond( [ { target: ethers.constants.AddressZero, - action: 2, + action: FacetCutAction.REMOVE, selectors: [selector], }, ], @@ -217,7 +221,7 @@ export function describeBehaviorOfSolidStateDiamond( await instance .connect(owner) .diamondCut( - [{ target: facet.address, action: 0, selectors }], + [{ target: facet.address, action: FacetCutAction.ADD, selectors }], ethers.constants.AddressZero, '0x', ); @@ -229,7 +233,7 @@ export function describeBehaviorOfSolidStateDiamond( [ { target: ethers.constants.AddressZero, - action: 2, + action: FacetCutAction.REMOVE, selectors: [selector], }, ], @@ -282,7 +286,7 @@ export function describeBehaviorOfSolidStateDiamond( await instance .connect(owner) .diamondCut( - [{ target: facet.address, action: 0, selectors }], + [{ target: facet.address, action: FacetCutAction.ADD, selectors }], ethers.constants.AddressZero, '0x', ); @@ -294,7 +298,7 @@ export function describeBehaviorOfSolidStateDiamond( [ { target: ethers.constants.AddressZero, - action: 2, + action: FacetCutAction.REMOVE, selectors: [selector], }, ], diff --git a/spec/proxy/diamond/writable/DiamondWritable.behavior.ts b/spec/proxy/diamond/writable/DiamondWritable.behavior.ts index e87447d4..387984c2 100644 --- a/spec/proxy/diamond/writable/DiamondWritable.behavior.ts +++ b/spec/proxy/diamond/writable/DiamondWritable.behavior.ts @@ -1,7 +1,7 @@ import { describeBehaviorOfERC165Base } from '../../../introspection'; import { deployMockContract } from '@ethereum-waffle/mock-contract'; import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers'; -import { describeFilter } from '@solidstate/library'; +import { describeFilter, FacetCutAction } from '@solidstate/library'; import { IDiamondWritable } from '@solidstate/typechain-types'; import { expect } from 'chai'; import { ethers } from 'hardhat'; @@ -68,7 +68,7 @@ export function describeBehaviorOfDiamondWritable( const facets: any = [ { target: facet.address, - action: 0, + action: FacetCutAction.ADD, selectors: [ethers.utils.hexlify(ethers.utils.randomBytes(4))], }, ]; @@ -107,13 +107,17 @@ export function describeBehaviorOfDiamondWritable( ); } - await instance - .connect(owner) - .diamondCut( - [{ target: facet.address, action: 0, selectors }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: facet.address, + action: FacetCutAction.ADD, + selectors, + }, + ], + ethers.constants.AddressZero, + '0x', + ); for (let fn of functions) { // call reverts, but with mock-specific message @@ -130,7 +134,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: ethers.constants.AddressZero, - action: 0, + action: FacetCutAction.ADD, selectors: [ethers.utils.randomBytes(4)], }, ], @@ -147,7 +151,7 @@ export function describeBehaviorOfDiamondWritable( const facetCuts = [ { target: facet.address, - action: 0, + action: FacetCutAction.ADD, selectors: [ethers.utils.randomBytes(4)], }, ]; @@ -176,13 +180,17 @@ export function describeBehaviorOfDiamondWritable( ethers.provider, ); - await instance - .connect(owner) - .diamondCut( - [{ target: facet.address, action: 0, selectors }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: facet.address, + action: FacetCutAction.ADD, + selectors, + }, + ], + ethers.constants.AddressZero, + '0x', + ); for (let fn of functions) { // call reverts, but with mock-specific message @@ -197,13 +205,17 @@ export function describeBehaviorOfDiamondWritable( expect(facetReplacement[fn]).not.to.be.undefined; } - await instance - .connect(owner) - .diamondCut( - [{ target: facetReplacement.address, action: 1, selectors }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: facetReplacement.address, + action: FacetCutAction.REPLACE, + selectors, + }, + ], + ethers.constants.AddressZero, + '0x', + ); for (let fn of functions) { // call reverts, but with mock-specific message @@ -220,7 +232,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: ethers.constants.AddressZero, - action: 1, + action: FacetCutAction.REPLACE, selectors: [ethers.utils.randomBytes(4)], }, ], @@ -239,7 +251,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: facet.address, - action: 1, + action: FacetCutAction.REPLACE, selectors: [ethers.utils.randomBytes(4)], }, ], @@ -259,7 +271,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: instance.address, - action: 0, + action: FacetCutAction.ADD, selectors: [selector], }, ], @@ -272,7 +284,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: facet.address, - action: 1, + action: FacetCutAction.REPLACE, selectors: [selector], }, ], @@ -292,7 +304,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: facet.address, - action: 0, + action: FacetCutAction.ADD, selectors: [selector], }, ], @@ -305,7 +317,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: facet.address, - action: 1, + action: FacetCutAction.REPLACE, selectors: [selector], }, ], @@ -328,13 +340,17 @@ export function describeBehaviorOfDiamondWritable( ethers.provider, ); - await instance - .connect(owner) - .diamondCut( - [{ target: facet.address, action: 0, selectors }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: facet.address, + action: FacetCutAction.ADD, + selectors, + }, + ], + ethers.constants.AddressZero, + '0x', + ); for (let fn of functions) { // call reverts, but with mock-specific message @@ -343,13 +359,17 @@ export function describeBehaviorOfDiamondWritable( ); } - await instance - .connect(owner) - .diamondCut( - [{ target: ethers.constants.AddressZero, action: 2, selectors }], - ethers.constants.AddressZero, - '0x', - ); + await instance.connect(owner).diamondCut( + [ + { + target: ethers.constants.AddressZero, + action: FacetCutAction.REMOVE, + selectors, + }, + ], + ethers.constants.AddressZero, + '0x', + ); for (let fn of functions) { await expect( @@ -368,7 +388,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: instance.address, - action: 2, + action: FacetCutAction.REMOVE, selectors: [ethers.utils.randomBytes(4)], }, ], @@ -387,7 +407,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: ethers.constants.AddressZero, - action: 2, + action: FacetCutAction.REMOVE, selectors: [ethers.utils.randomBytes(4)], }, ], @@ -407,7 +427,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: instance.address, - action: 0, + action: FacetCutAction.ADD, selectors: [selector], }, ], @@ -420,7 +440,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: ethers.constants.AddressZero, - action: 2, + action: FacetCutAction.REMOVE, selectors: [selector], }, ], @@ -466,7 +486,7 @@ export function describeBehaviorOfDiamondWritable( [ { target: ethers.constants.AddressZero, - action: 0, + action: FacetCutAction.ADD, selectors: [], }, ], diff --git a/test/proxy/diamond/SolidStateDiamond.ts b/test/proxy/diamond/SolidStateDiamond.ts index 508e8e41..b9766414 100644 --- a/test/proxy/diamond/SolidStateDiamond.ts +++ b/test/proxy/diamond/SolidStateDiamond.ts @@ -1,4 +1,5 @@ import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers'; +import { FacetCutAction } from '@solidstate/library'; import { describeBehaviorOfSolidStateDiamond } from '@solidstate/spec'; import { SolidStateDiamond, @@ -30,7 +31,7 @@ describe('SolidStateDiamond', function () { facetCuts[0] = { target: instance.address, - action: 0, + action: FacetCutAction.ADD, selectors: facets[0].selectors, }; }); diff --git a/test/proxy/diamond/base/DiamondBase.ts b/test/proxy/diamond/base/DiamondBase.ts index f317eba6..cf874d68 100644 --- a/test/proxy/diamond/base/DiamondBase.ts +++ b/test/proxy/diamond/base/DiamondBase.ts @@ -1,3 +1,4 @@ +import { FacetCutAction } from '@solidstate/library'; import { describeBehaviorOfDiamondBase } from '@solidstate/spec'; import { DiamondBaseMock, @@ -18,7 +19,7 @@ describe('DiamondBase', function () { instance = await new DiamondBaseMock__factory(deployer).deploy([ { target: facetInstance.address, - action: 0, + action: FacetCutAction.ADD, selectors: [facetInstance.interface.getSighash('owner()')], }, ]); diff --git a/test/proxy/diamond/fallback/DiamondFallback.ts b/test/proxy/diamond/fallback/DiamondFallback.ts index 75404ead..6734f557 100644 --- a/test/proxy/diamond/fallback/DiamondFallback.ts +++ b/test/proxy/diamond/fallback/DiamondFallback.ts @@ -1,4 +1,5 @@ import { SignerWithAddress } from '@nomiclabs/hardhat-ethers/signers'; +import { FacetCutAction } from '@solidstate/library'; import { describeBehaviorOfDiamondFallback } from '@solidstate/spec'; import { DiamondFallbackMock, @@ -25,7 +26,7 @@ describe('DiamondFallback', function () { instance = await new DiamondFallbackMock__factory(deployer).deploy([ { target: facetInstance.address, - action: 0, + action: FacetCutAction.ADD, selectors: [facetInstance.interface.getSighash('owner()')], }, ]); diff --git a/test/proxy/diamond/readable/DiamondReadable.ts b/test/proxy/diamond/readable/DiamondReadable.ts index 6c75969a..3e1eb98c 100644 --- a/test/proxy/diamond/readable/DiamondReadable.ts +++ b/test/proxy/diamond/readable/DiamondReadable.ts @@ -1,4 +1,5 @@ import { deployMockContract } from '@ethereum-waffle/mock-contract'; +import { FacetCutAction } from '@solidstate/library'; import { describeBehaviorOfDiamondReadable } from '@solidstate/spec'; import { DiamondReadableMock, @@ -35,7 +36,7 @@ describe('DiamondReadable', function () { facetCuts.push({ target: facet.address, - action: 0, + action: FacetCutAction.ADD, selectors, }); });