From 258467a43e2f21726dd1842c4b740db3e234d5aa Mon Sep 17 00:00:00 2001 From: Hector Date: Tue, 5 Jan 2021 00:39:23 -0500 Subject: [PATCH] Reader updates and backward compatibility (#360) * some error handling * A few fields for backward compatibility. * a few datapack api fix * pylint. * Add data. * fix mypy Co-authored-by: hector.liu --- forte/data/base_pack.py | 30 ++++++++-------- forte/data/container.py | 21 ++++++++--- forte/data/data_pack.py | 25 +++++++++---- forte/data/multi_pack.py | 2 +- forte/data/ontology/top.py | 45 +++++++++++++++++++++--- forte/data/readers/deserialize_reader.py | 26 ++++++++++---- forte/data/readers/stave_readers.py | 6 ++-- 7 files changed, 113 insertions(+), 42 deletions(-) diff --git a/forte/data/base_pack.py b/forte/data/base_pack.py index 7367e5ff8..297046324 100644 --- a/forte/data/base_pack.py +++ b/forte/data/base_pack.py @@ -97,8 +97,10 @@ def __getstate__(self): def __setstate__(self, state): super().__setstate__(state) - self.__dict__['_pending_entries'] = {} - self.__control_component: Optional[str] = None + if 'meta' in self.__dict__: + self._meta = self.__dict__.pop('meta') + self.__control_component = None + self._pending_entries = {} @abstractmethod def _init_meta(self, pack_name: Optional[str] = None) -> BaseMeta: @@ -220,8 +222,8 @@ def add_all_remaining_entries(self, component: Optional[str] = None): def serialize(self, drop_record: Optional[bool] = False) -> str: r"""Serializes a pack to a string.""" if drop_record: - self.creation_records.clear() - self.field_records.clear() + self._creation_records.clear() + self._field_records.clear() return jsonpickle.encode(self, unpicklable=True) @@ -249,9 +251,9 @@ def record_entry(self, entry: Entry, component_name: Optional[str] = None): if c is not None: try: - self.creation_records[c].add(entry.tid) + self._creation_records[c].add(entry.tid) except KeyError: - self.creation_records[c] = {entry.tid} + self._creation_records[c] = {entry.tid} def record_field(self, 
entry_id: int, field_name: str): """ @@ -269,9 +271,9 @@ def record_field(self, entry_id: int, field_name: str): if c is not None: try: - self.field_records[c].add((entry_id, field_name)) + self._field_records[c].add((entry_id, field_name)) except KeyError: - self.field_records[c] = {(entry_id, field_name)} + self._field_records[c] = {(entry_id, field_name)} def on_entry_creation(self, entry: Entry, component_name: Optional[str] = None): @@ -346,12 +348,12 @@ def get_single(self, entry_type: Type[EntryType]) -> EntryType: raise EntryNotFoundError( f"The entry {entry_type} is not found in the provided pack.") - def get_ids_by_component(self, component: str) -> Set[int]: + def get_ids_by_creator(self, component: str) -> Set[int]: r"""Look up the component_index with key ``component``.""" - entry_set: Set[int] = self.creation_records[component] + entry_set: Set[int] = self._creation_records[component] return entry_set - def get_entries_by_component(self, component: str) -> Set[EntryType]: + def get_entries_by_creator(self, component: str) -> Set[EntryType]: """ Return all entries created by the particular component, an unordered set. 
@@ -363,13 +365,13 @@ def get_entries_by_component(self, component: str) -> Set[EntryType]: """ return {self.get_entry(tid) - for tid in self.get_ids_by_component(component)} + for tid in self.get_ids_by_creator(component)} - def get_ids_by_components(self, components: List[str]) -> Set[int]: + def get_ids_by_creators(self, components: List[str]) -> Set[int]: """Look up component_index using a list of components.""" valid_component_id: Set[int] = set() for component in components: - valid_component_id |= self.get_ids_by_component(component) + valid_component_id |= self.get_ids_by_creator(component) return valid_component_id def get_ids_by_type(self, entry_type: Type[EntryType]) -> Set[int]: diff --git a/forte/data/container.py b/forte/data/container.py index 4b6806816..df2e08f0b 100644 --- a/forte/data/container.py +++ b/forte/data/container.py @@ -19,7 +19,7 @@ # pylint: disable=function-redefined,multiple-statements from abc import abstractmethod -from typing import Dict, Generic, Set, Tuple, TypeVar +from typing import Dict, Generic, Set, Tuple, TypeVar, Iterator from forte.data.span import Span @@ -66,11 +66,11 @@ def current_id_counter(self) -> int: class EntryContainer(Generic[E, L, G]): def __init__(self): # Record the set of entries created by some components. - self.creation_records: Dict[str, Set[int]] = {} + self._creation_records: Dict[str, Set[int]] = {} # Record the set of fields modified by this component. The 2-tuple # identify the entry field, such as (2, lemma). - self.field_records: Dict[str, Set[Tuple[int, str]]] = {} + self._field_records: Dict[str, Set[Tuple[int, str]]] = {} # The Id manager controls the ID management in this container self._id_manager = EntryIdManager() @@ -93,8 +93,16 @@ def __setstate__(self, state): - The :class:`IdManager` is recreated from the id count. 
""" self.__dict__.update(state) - self.__dict__.pop('serialization') - self._id_manager = EntryIdManager(state['serialization']['next_id']) + + if 'creation_records' in self.__dict__: + self._creation_records = self.__dict__.pop('creation_records') + + if 'field_records' in self.__dict__: + self._field_records = self.__dict__.pop('field_records') + + if 'serialization' in self.__dict__: + self._id_manager = EntryIdManager( + self.__dict__.pop('serialization')['next_id']) @abstractmethod def on_entry_creation(self, entry: E): @@ -130,5 +138,8 @@ def get_span_text(self, span: Span): def get_next_id(self): return self._id_manager.get_id() + def get_all_creator(self) -> Iterator[str]: + yield from self._creation_records.keys() + ContainerType = TypeVar("ContainerType", bound=EntryContainer) diff --git a/forte/data/data_pack.py b/forte/data/data_pack.py index bc62a115f..e33d24939 100644 --- a/forte/data/data_pack.py +++ b/forte/data/data_pack.py @@ -105,6 +105,16 @@ def __setstate__(self, state): """ super().__setstate__(state) + # For backward compatibility. 
+ if 'replace_back_operations' in self.__dict__: + self.__replace_back_operations = self.__dict__.pop( + 'replace_back_operations') + if 'processed_original_spans' in self.__dict__: + self.__processed_original_spans = self.__dict__.pop( + 'processed_original_spans') + if 'orig_text_len' in self.__dict__: + self.__orig_text_len = self.__dict__.pop('orig_text_len') + self.annotations = SortedList(self.annotations) self.links = SortedList(self.links) self.groups = SortedList(self.groups) @@ -632,7 +642,7 @@ def get_data(self, context_type: Type[Annotation], if context_components: valid_component_id: Set[int] = set() for component in context_components: - valid_component_id |= self.get_ids_by_component(component) + valid_component_id |= self.get_ids_by_creator(component) valid_context_ids &= valid_component_id skipped = 0 @@ -855,7 +865,8 @@ def get(self, entry_type: Type[EntryType], # type: ignore range_annotation: Optional[Annotation] = None, components: Optional[Union[str, List[str]]] = None ) -> Iterable[EntryType]: - r"""This is a shorthand alias to :func:`get_entries` + r"""This function is used to get data from a data pack with various + methods. Example: @@ -875,9 +886,9 @@ def get(self, entry_type: Type[EntryType], # type: ignore range_annotation (Annotation, optional): The range of entries requested. If `None`, will return valid entries in the range of whole data_pack. - components (str or list, optional): The component generating the - entries requested. If `None`, will return valid entries - generated by any component. + components (str or list, optional): The component (creator) + generating the entries requested. If `None`, will return valid + entries generated by any component. """ # If we don't have any annotations, then we yield an empty list. # Note that generics do not work with annotations. 
@@ -891,7 +902,7 @@ def get(self, entry_type: Type[EntryType], # type: ignore if components is not None: if isinstance(components, str): components = [components] - valid_id &= self.get_ids_by_components(components) + valid_id &= self.get_ids_by_creators(components) # Generics do not work with range_annotation. if issubclass(entry_type, Generics): @@ -968,7 +979,7 @@ class DataIndex(BaseIndex): def __init__(self): super().__init__() self._coverage_index: Dict[Tuple[Type[Annotation], Type[EntryType]], - Dict[int, Set[int]]] = dict() + Dict[int, Set[int]]] = dict() self._coverage_index_valid = True @property diff --git a/forte/data/multi_pack.py b/forte/data/multi_pack.py index 7386b6acd..b14a2aff8 100644 --- a/forte/data/multi_pack.py +++ b/forte/data/multi_pack.py @@ -534,7 +534,7 @@ def get(self, entry_type: Type[EntryType], # type: ignore if components is not None: if isinstance(components, str): components = [components] - valid_id &= self.get_ids_by_components(components) + valid_id &= self.get_ids_by_creators(components) for entry_id in valid_id: yield self.get_entry(entry_id) diff --git a/forte/data/ontology/top.py b/forte/data/ontology/top.py index 89d5ff7a7..220184d74 100644 --- a/forte/data/ontology/top.py +++ b/forte/data/ontology/top.py @@ -17,7 +17,6 @@ import numpy as np -from forte.common.exception import IncompleteEntryError from forte.data.base_pack import PackType from forte.data.ontology.core import Entry, BaseLink, BaseGroup, MultiEntry from forte.data.span import Span @@ -282,15 +281,51 @@ def __init__( @property def parent(self) -> Tuple[int, int]: if self._parent is None: - raise IncompleteEntryError("Parent is not set for this link.") + raise ValueError("Parent is not set for this link.") return self._parent @property def child(self) -> Tuple[int, int]: if self._child is None: - raise IncompleteEntryError("Child is not set for this link.") + raise ValueError("Child is not set for this link.") return self._child + def parent_id(self) -> int: + 
""" + Return the `tid` of the parent entry. + + Returns: The `tid` of the parent entry. + """ + return self.parent[1] + + def child_id(self) -> int: + """ + Return the `tid` of the child entry. + + Returns: The `tid` of the child entry. + """ + return self.child[1] + + def parent_pack_id(self) -> int: + """ + Return the `pack_id` of the parent pack. + + Returns: The `pack_id` of the parent pack. + """ + if self._parent is None: + raise ValueError("Parent is not set for this link.") + return self.pack.packs[self._parent[0]].pack_id + + def child_pack_id(self) -> int: + """ + Return the `pack_id` of the child pack. + + Returns: The `pack_id` of the child pack. + """ + if self._child is None: + raise ValueError("Child is not set for this link.") + return self.pack.packs[self._child[0]].pack_id + def set_parent(self, parent: Entry): r"""This will set the `parent` of the current instance with given Entry. The parent is saved internally as a tuple: ``pack index`` and @@ -331,7 +366,7 @@ def get_parent(self) -> Entry: An instance of :class:`Entry` that is the parent of the link. """ if self._parent is None: - raise IncompleteEntryError("The parent of this link is not set.") + raise ValueError("The parent of this link is not set.") pack_idx, parent_tid = self._parent return self.pack.get_subentry(pack_idx, parent_tid) @@ -343,7 +378,7 @@ def get_child(self) -> Entry: An instance of :class:`Entry` that is the child of the link. 
""" if self._child is None: - raise IncompleteEntryError("The parent of this link is not set.") + raise ValueError("The parent of this link is not set.") pack_idx, child_tid = self._child return self.pack.get_subentry(pack_idx, child_tid) diff --git a/forte/data/readers/deserialize_reader.py b/forte/data/readers/deserialize_reader.py index 1a350524c..dd3ce9b36 100644 --- a/forte/data/readers/deserialize_reader.py +++ b/forte/data/readers/deserialize_reader.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from abc import ABC, abstractmethod -from typing import Iterator, List, Any, Union +from typing import Iterator, List, Any, Union, Optional from forte.common.exception import ProcessExecutionException from forte.data.data_pack import DataPack @@ -126,6 +127,11 @@ def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]: for pid in m_pack.pack_ids(): p_content = self._get_pack_content(pid) + if p_content is None: + logging.warning( + "Cannot locate the data pack with pid %d " + "for multi pack %d", pid, m_pack.pack_id) + break pack: DataPack if isinstance(p_content, str): pack = DataPack.deserialize(p_content) @@ -133,7 +139,9 @@ def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]: pack = p_content # Only in deserialization we can do this. m_pack.packs.append(pack) - yield m_pack + else: + # No multi pack will be yielded if there are packs not located. 
+ yield m_pack @abstractmethod def _get_multipack_content(self, *args: Any, **kwargs: Any @@ -148,7 +156,7 @@ def _get_multipack_content(self, *args: Any, **kwargs: Any raise NotImplementedError @abstractmethod - def _get_pack_content(self, pack_id: int) -> Union[str, DataPack]: + def _get_pack_content(self, pack_id: int) -> Union[None, str, DataPack]: """ Implementation of this method should be responsible for returning the raw string of the data pack from the pack id. @@ -182,10 +190,14 @@ def _get_multipack_content(self) -> Iterator[str]: # type: ignore self.configs.multi_pack_dir, f)) as m_data: yield m_data.read() - def _get_pack_content(self, pack_id: int) -> str: - with open(os.path.join( - self.configs.data_pack_dir, f'{pack_id}.json')) as pack_data: - return pack_data.read() + def _get_pack_content(self, pack_id: int) -> Optional[str]: + pack_path = os.path.join( + self.configs.data_pack_dir, f'{pack_id}.json') + if os.path.exists(pack_path): + with open(pack_path) as pack_data: + return pack_data.read() + else: + return None @classmethod def default_configs(cls): diff --git a/forte/data/readers/stave_readers.py b/forte/data/readers/stave_readers.py index 5af523a5f..c7ac18d30 100644 --- a/forte/data/readers/stave_readers.py +++ b/forte/data/readers/stave_readers.py @@ -19,7 +19,7 @@ """ import sqlite3 -from typing import Iterator, Dict +from typing import Iterator, Dict, Optional from forte.common import Resources, ProcessorConfigError from forte.common.configuration import Config @@ -85,8 +85,8 @@ def _get_multipack_content(self) -> Iterator[str]: # type: ignore f'SELECT textPack FROM {self.configs.multipack_table}'): yield value[0] - def _get_pack_content(self, pack_id: int) -> DataPack: - return self.data_packs[pack_id] + def _get_pack_content(self, pack_id: int) -> Optional[DataPack]: + return self.data_packs.get(pack_id, None) @classmethod def default_configs(cls):