Merge pull request #602 from lidofinance/feat/remove-extra-data-v1
Feat: Remove extra data v1
F4ever authored Jan 24, 2025
2 parents dd67e0d + fb2e32e commit 01ec02a
Showing 5 changed files with 92 additions and 443 deletions.
src/modules/accounting/accounting.py (15 changes: 2 additions & 13 deletions)
@@ -9,7 +9,6 @@
from src import variables
from src.constants import SHARE_RATE_PRECISION_E27
from src.modules.accounting.third_phase.extra_data import ExtraDataService
from src.modules.accounting.third_phase.extra_data_v2 import ExtraDataServiceV2
from src.modules.accounting.third_phase.types import ExtraData, FormatList
from src.modules.accounting.types import (
ReportData,
@@ -333,24 +332,14 @@ def _is_bunker(self, blockstamp: ReferenceBlockStamp) -> BunkerMode:

@lru_cache(maxsize=1)
def get_extra_data(self, blockstamp: ReferenceBlockStamp) -> ExtraData:
consensus_version = self.w3.lido_contracts.accounting_oracle.get_consensus_version(blockstamp.block_hash)

chain_config = self.get_chain_config(blockstamp)
stuck_validators = self.lido_validator_state_service.get_lido_newly_stuck_validators(blockstamp, chain_config)
logger.info({'msg': 'Calculate stuck validators.', 'value': stuck_validators})
exited_validators = self.lido_validator_state_service.get_lido_newly_exited_validators(blockstamp)
logger.info({'msg': 'Calculate exited validators.', 'value': exited_validators})
orl = self.w3.lido_contracts.oracle_report_sanity_checker.get_oracle_report_limits(blockstamp.block_hash)

if consensus_version == 1:
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
orl.max_node_operators_per_extra_data_item,
)

return ExtraDataServiceV2.collect(
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
@@ -383,7 +372,7 @@ def _calculate_wq_report(self, blockstamp: ReferenceBlockStamp) -> WqReport:

def _calculate_extra_data_report(self, blockstamp: ReferenceBlockStamp) -> ExtraData:
stuck_validators, exited_validators, orl = self._get_generic_extra_data(blockstamp)
return ExtraDataServiceV2.collect(
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
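The new call above is cut off by the collapsed hunk. Below is a minimal sketch of the single remaining extra-data path after this change, with the final argument inferred from the removed v1 branch and from the collect signature in extra_data.py further down; it is an illustration, not part of the diff:

# Sketch only: both report phases now build extra data through the one
# ExtraDataService (the former v2 implementation).
stuck_validators, exited_validators, orl = self._get_generic_extra_data(blockstamp)

extra_data = ExtraDataService.collect(
    stuck_validators,                            # dict[NodeOperatorGlobalIndex, int]
    exited_validators,                           # dict[NodeOperatorGlobalIndex, int]
    orl.max_items_per_extra_data_transaction,    # cap on items packed into one transaction
    orl.max_node_operators_per_extra_data_item,  # cap on node operators per payload item
)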
src/modules/accounting/third_phase/extra_data.py (159 changes: 82 additions & 77 deletions)
@@ -1,133 +1,138 @@
import itertools
from dataclasses import dataclass
from itertools import groupby, batched
from typing import Sequence

from hexbytes import HexBytes

from src.modules.accounting.third_phase.types import ItemType, ExtraData, FormatList, ExtraDataLengths
from src.modules.accounting.third_phase.types import ExtraData, ItemType, ExtraDataLengths, FormatList
from src.modules.submodules.types import ZERO_HASH
from src.types import NodeOperatorGlobalIndex
from src.web3py.types import Web3


@dataclass
class ItemPayload:
module_id: bytes
node_ops_count: bytes
node_operator_ids: bytes
vals_counts: bytes


@dataclass
class ExtraDataItem:
item_index: bytes
item_type: ItemType
item_payload: ItemPayload
module_id: int
node_operator_ids: Sequence[int]
vals_counts: Sequence[int]


class ExtraDataService:
"""
Service that encodes extra data into bytes in correct order.
Extra data is an array of items, each item being encoded as follows:
| 3 bytes | 2 bytes | X bytes |
| itemIndex | itemType | itemPayload |
| 32 bytes | 3 bytes | 2 bytes | X bytes |
| nextHash | itemIndex | itemType | itemPayload |
itemPayload format:
| 3 bytes | 8 bytes | nodeOpsCount * 8 bytes | nodeOpsCount * 16 bytes |
| moduleId | nodeOpsCount | nodeOperatorIds | stuckOrExitedValsCount |
max_items_count - max itemIndex in extra data.
max_items_count_per_tx - max itemIndex in extra data.
max_no_in_payload_count - max nodeOpsCount that could be used in itemPayload.
"""
@classmethod
def collect(
cls,
stuck_validators: dict[NodeOperatorGlobalIndex, int],
exited_validators: dict[NodeOperatorGlobalIndex, int],
max_items_count: int,
max_items_count_per_tx: int,
max_no_in_payload_count: int,
) -> ExtraData:
stuck_payloads = cls.build_validators_payloads(stuck_validators, max_no_in_payload_count)
exited_payloads = cls.build_validators_payloads(exited_validators, max_no_in_payload_count)
items_count, txs = cls.build_extra_transactions_data(stuck_payloads, exited_payloads, max_items_count_per_tx)
first_hash, hashed_txs = cls.add_hashes_to_transactions(txs)

extra_data = cls.build_extra_data(stuck_payloads, exited_payloads, max_items_count)
extra_data_bytes = cls.to_bytes(extra_data)

if extra_data:
extra_data_list = [extra_data_bytes]
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
data_hash = Web3.keccak(extra_data_bytes)
if items_count:
extra_data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
else:
extra_data_list = []
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY
data_hash = HexBytes(ZERO_HASH)
extra_data_format = FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY

return ExtraData(
extra_data_list=extra_data_list,
data_hash=data_hash,
format=data_format.value,
items_count=len(extra_data),
items_count=items_count,
extra_data_list=hashed_txs,
data_hash=first_hash,
format=extra_data_format.value,
)

@staticmethod
@classmethod
def build_validators_payloads(
cls,
validators: dict[NodeOperatorGlobalIndex, int],
max_no_in_payload_count: int,
) -> list[ItemPayload]:
# sort by module id and node operator id
operator_validators = sorted(validators.items(), key=lambda x: x[0])

payloads = []

for module_id, operators_by_module in itertools.groupby(operator_validators, key=lambda x: x[0][0]):
operator_ids = []
vals_count = []

for ((_, no_id), validators_count) in list(operators_by_module)[:max_no_in_payload_count]:
operator_ids.append(no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_ID))
vals_count.append(validators_count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT))

payloads.append(
ItemPayload(
module_id=module_id.to_bytes(ExtraDataLengths.MODULE_ID),
node_ops_count=len(operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT),
node_operator_ids=b"".join(operator_ids),
vals_counts=b"".join(vals_count),
for module_id, operators_by_module in groupby(operator_validators, key=lambda x: x[0][0]):
for nos_in_batch in batched(list(operators_by_module), max_no_in_payload_count):
operator_ids = []
vals_count = []

for ((_, no_id), validators_count) in nos_in_batch:
operator_ids.append(no_id)
vals_count.append(validators_count)

payloads.append(
ItemPayload(
module_id=module_id,
node_operator_ids=operator_ids,
vals_counts=vals_count,
)
)
)

return payloads

@staticmethod
def build_extra_data(stuck_payloads: list[ItemPayload], exited_payloads: list[ItemPayload], max_items_count: int):
@classmethod
def build_extra_transactions_data(
cls,
stuck_payloads: list[ItemPayload],
exited_payloads: list[ItemPayload],
max_items_count_per_tx: int,
) -> tuple[int, list[bytes]]:
all_payloads = [
*[(ItemType.EXTRA_DATA_TYPE_STUCK_VALIDATORS, payload) for payload in stuck_payloads],
*[(ItemType.EXTRA_DATA_TYPE_EXITED_VALIDATORS, payload) for payload in exited_payloads],
]

index = 0
extra_data = []

for item_type, payloads in [
(ItemType.EXTRA_DATA_TYPE_STUCK_VALIDATORS, stuck_payloads),
(ItemType.EXTRA_DATA_TYPE_EXITED_VALIDATORS, exited_payloads),
]:
for payload in payloads:
extra_data.append(ExtraDataItem(
item_index=index.to_bytes(ExtraDataLengths.ITEM_INDEX),
item_type=item_type,
item_payload=payload
))
result = []

for payload_batch in batched(all_payloads, max_items_count_per_tx):
tx_body = b''
for item_type, payload in payload_batch:
tx_body += index.to_bytes(ExtraDataLengths.ITEM_INDEX)
tx_body += item_type.value.to_bytes(ExtraDataLengths.ITEM_TYPE)
tx_body += payload.module_id.to_bytes(ExtraDataLengths.MODULE_ID)
tx_body += len(payload.node_operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT)
tx_body += b''.join(
no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_ID)
for no_id in payload.node_operator_ids
)
tx_body += b''.join(
count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT)
for count in payload.vals_counts
)

index += 1
if index == max_items_count:
return extra_data

return extra_data
result.append(tx_body)

return index, result

@staticmethod
def to_bytes(extra_data: list[ExtraDataItem]) -> bytes:
extra_data_bytes = b''
for item in extra_data:
extra_data_bytes += item.item_index
extra_data_bytes += item.item_type.value.to_bytes(ExtraDataLengths.ITEM_TYPE)
extra_data_bytes += item.item_payload.module_id
extra_data_bytes += item.item_payload.node_ops_count
extra_data_bytes += item.item_payload.node_operator_ids
extra_data_bytes += item.item_payload.vals_counts
return extra_data_bytes
def add_hashes_to_transactions(txs_data: list[bytes]) -> tuple[bytes, list[bytes]]:
txs_data.reverse()

txs_with_hashes = []
next_hash = ZERO_HASH

for tx in txs_data:
full_tx_data = next_hash + tx
txs_with_hashes.append(full_tx_data)
next_hash = Web3.keccak(full_tx_data)

txs_with_hashes.reverse()

return next_hash, txs_with_hashes
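The chained layout built by add_hashes_to_transactions above can be reproduced and checked outside the oracle. A minimal, self-contained sketch follows; it assumes only the web3 package, ZERO_HASH is redefined locally as a stand-in for the constant imported in the module, and verify_chain is an illustrative helper rather than repository code:

from web3 import Web3

ZERO_HASH = b"\x00" * 32  # stand-in for src.modules.submodules.types.ZERO_HASH


def chain_transactions(tx_bodies: list[bytes]) -> tuple[bytes, list[bytes]]:
    # Walk the transactions back to front: each one is prefixed with the keccak
    # hash of the next prefixed transaction, the last one with the zero hash.
    next_hash = ZERO_HASH
    chained = []
    for body in reversed(tx_bodies):
        full_tx = next_hash + body
        chained.append(full_tx)
        next_hash = Web3.keccak(full_tx)
    chained.reverse()
    # next_hash is now the keccak of the first prefixed transaction; it commits to
    # the whole list and is what collect() reports as data_hash.
    return next_hash, chained


def verify_chain(data_hash: bytes, chained: list[bytes]) -> bool:
    # data_hash must cover the first transaction, every 32-byte prefix must equal
    # the hash of the following transaction, and the last prefix must be zero.
    if not chained:
        return data_hash == ZERO_HASH
    if Web3.keccak(chained[0]) != data_hash:
        return False
    for current, following in zip(chained, chained[1:]):
        if current[:32] != Web3.keccak(following):
            return False
    return chained[-1][:32] == ZERO_HASH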
src/modules/accounting/third_phase/extra_data_v2.py (138 changes: 0 additions & 138 deletions)

This file was deleted.

