Skip to content

Commit

Permalink
Reader updates and backward compatibility (#360)
Browse files Browse the repository at this point in the history
* some error handling

* A few fields for backward compatibility.

* A few DataPack API fixes.

* pylint.

* Add data.

* fix mypy

Co-authored-by: hector.liu <hector.liu@petuum.com>
  • Loading branch information
hunterhector and hector.liu authored Jan 5, 2021
1 parent b6986d2 commit 258467a
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 42 deletions.
30 changes: 16 additions & 14 deletions forte/data/base_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ def __getstate__(self):

def __setstate__(self, state):
super().__setstate__(state)
self.__dict__['_pending_entries'] = {}
self.__control_component: Optional[str] = None
if 'meta' in self.__dict__:
self._meta = self.__dict__.pop('meta')
self.__control_component = None
self._pending_entries = {}

@abstractmethod
def _init_meta(self, pack_name: Optional[str] = None) -> BaseMeta:
Expand Down Expand Up @@ -220,8 +222,8 @@ def add_all_remaining_entries(self, component: Optional[str] = None):
def serialize(self, drop_record: Optional[bool] = False) -> str:
r"""Serializes a pack to a string."""
if drop_record:
self.creation_records.clear()
self.field_records.clear()
self._creation_records.clear()
self._field_records.clear()

return jsonpickle.encode(self, unpicklable=True)

Expand Down Expand Up @@ -249,9 +251,9 @@ def record_entry(self, entry: Entry, component_name: Optional[str] = None):

if c is not None:
try:
self.creation_records[c].add(entry.tid)
self._creation_records[c].add(entry.tid)
except KeyError:
self.creation_records[c] = {entry.tid}
self._creation_records[c] = {entry.tid}

def record_field(self, entry_id: int, field_name: str):
"""
Expand All @@ -269,9 +271,9 @@ def record_field(self, entry_id: int, field_name: str):

if c is not None:
try:
self.field_records[c].add((entry_id, field_name))
self._field_records[c].add((entry_id, field_name))
except KeyError:
self.field_records[c] = {(entry_id, field_name)}
self._field_records[c] = {(entry_id, field_name)}

def on_entry_creation(self, entry: Entry,
component_name: Optional[str] = None):
Expand Down Expand Up @@ -346,12 +348,12 @@ def get_single(self, entry_type: Type[EntryType]) -> EntryType:
raise EntryNotFoundError(
f"The entry {entry_type} is not found in the provided pack.")

def get_ids_by_component(self, component: str) -> Set[int]:
def get_ids_by_creator(self, component: str) -> Set[int]:
r"""Look up the component_index with key ``component``."""
entry_set: Set[int] = self.creation_records[component]
entry_set: Set[int] = self._creation_records[component]
return entry_set

def get_entries_by_component(self, component: str) -> Set[EntryType]:
def get_entries_by_creator(self, component: str) -> Set[EntryType]:
"""
Return all entries created by the particular component, an unordered
set.
Expand All @@ -363,13 +365,13 @@ def get_entries_by_component(self, component: str) -> Set[EntryType]:
"""
return {self.get_entry(tid)
for tid in self.get_ids_by_component(component)}
for tid in self.get_ids_by_creator(component)}

def get_ids_by_components(self, components: List[str]) -> Set[int]:
def get_ids_by_creators(self, components: List[str]) -> Set[int]:
"""Look up component_index using a list of components."""
valid_component_id: Set[int] = set()
for component in components:
valid_component_id |= self.get_ids_by_component(component)
valid_component_id |= self.get_ids_by_creator(component)
return valid_component_id

def get_ids_by_type(self, entry_type: Type[EntryType]) -> Set[int]:
Expand Down
21 changes: 16 additions & 5 deletions forte/data/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# pylint: disable=function-redefined,multiple-statements

from abc import abstractmethod
from typing import Dict, Generic, Set, Tuple, TypeVar
from typing import Dict, Generic, Set, Tuple, TypeVar, Iterator

from forte.data.span import Span

Expand Down Expand Up @@ -66,11 +66,11 @@ def current_id_counter(self) -> int:
class EntryContainer(Generic[E, L, G]):
def __init__(self):
# Record the set of entries created by some components.
self.creation_records: Dict[str, Set[int]] = {}
self._creation_records: Dict[str, Set[int]] = {}

# Record the set of fields modified by this component. The 2-tuple
# identify the entry field, such as (2, lemma).
self.field_records: Dict[str, Set[Tuple[int, str]]] = {}
self._field_records: Dict[str, Set[Tuple[int, str]]] = {}

# The Id manager controls the ID management in this container
self._id_manager = EntryIdManager()
Expand All @@ -93,8 +93,16 @@ def __setstate__(self, state):
- The :class:`IdManager` is recreated from the id count.
"""
self.__dict__.update(state)
self.__dict__.pop('serialization')
self._id_manager = EntryIdManager(state['serialization']['next_id'])

if 'creation_records' in self.__dict__:
self._creation_records = self.__dict__.pop('creation_records')

if 'field_records' in self.__dict__:
self._field_records = self.__dict__.pop('field_records')

if 'serialization' in self.__dict__:
self._id_manager = EntryIdManager(
self.__dict__.pop('serialization')['next_id'])

@abstractmethod
def on_entry_creation(self, entry: E):
Expand Down Expand Up @@ -130,5 +138,8 @@ def get_span_text(self, span: Span):
def get_next_id(self):
return self._id_manager.get_id()

def get_all_creator(self) -> Iterator[str]:
    """Iterate over the names of all recorded entry creators.

    The names are the keys of ``self._creation_records``, which maps a
    component name to the set of entry ids recorded for that component.

    Yields:
        Each key of ``self._creation_records``, in the dictionary's own
        iteration order.
    """
    yield from self._creation_records.keys()


ContainerType = TypeVar("ContainerType", bound=EntryContainer)
25 changes: 18 additions & 7 deletions forte/data/data_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def __setstate__(self, state):
"""
super().__setstate__(state)

# For backward compatibility.
if 'replace_back_operations' in self.__dict__:
self.__replace_back_operations = self.__dict__.pop(
'replace_back_operations')
if 'processed_original_spans' in self.__dict__:
self.__processed_original_spans = self.__dict__.pop(
'processed_original_spans')
if 'orig_text_len' in self.__dict__:
self.__orig_text_len = self.__dict__.pop('orig_text_len')

self.annotations = SortedList(self.annotations)
self.links = SortedList(self.links)
self.groups = SortedList(self.groups)
Expand Down Expand Up @@ -632,7 +642,7 @@ def get_data(self, context_type: Type[Annotation],
if context_components:
valid_component_id: Set[int] = set()
for component in context_components:
valid_component_id |= self.get_ids_by_component(component)
valid_component_id |= self.get_ids_by_creator(component)
valid_context_ids &= valid_component_id

skipped = 0
Expand Down Expand Up @@ -855,7 +865,8 @@ def get(self, entry_type: Type[EntryType], # type: ignore
range_annotation: Optional[Annotation] = None,
components: Optional[Union[str, List[str]]] = None
) -> Iterable[EntryType]:
r"""This is a shorthand alias to :func:`get_entries`
r"""This function is used to get data from a data pack with various
methods.
Example:
Expand All @@ -875,9 +886,9 @@ def get(self, entry_type: Type[EntryType], # type: ignore
range_annotation (Annotation, optional): The range of entries
requested. If `None`, will return valid entries in the range of
whole data_pack.
components (str or list, optional): The component generating the
entries requested. If `None`, will return valid entries
generated by any component.
components (str or list, optional): The component (creator)
generating the entries requested. If `None`, will return valid
entries generated by any component.
"""
# If we don't have any annotations, then we yield an empty list.
# Note that generics do not work with annotations.
Expand All @@ -891,7 +902,7 @@ def get(self, entry_type: Type[EntryType], # type: ignore
if components is not None:
if isinstance(components, str):
components = [components]
valid_id &= self.get_ids_by_components(components)
valid_id &= self.get_ids_by_creators(components)

# Generics do not work with range_annotation.
if issubclass(entry_type, Generics):
Expand Down Expand Up @@ -968,7 +979,7 @@ class DataIndex(BaseIndex):
def __init__(self):
super().__init__()
self._coverage_index: Dict[Tuple[Type[Annotation], Type[EntryType]],
Dict[int, Set[int]]] = dict()
Dict[int, Set[int]]] = dict()
self._coverage_index_valid = True

@property
Expand Down
2 changes: 1 addition & 1 deletion forte/data/multi_pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def get(self, entry_type: Type[EntryType], # type: ignore
if components is not None:
if isinstance(components, str):
components = [components]
valid_id &= self.get_ids_by_components(components)
valid_id &= self.get_ids_by_creators(components)

for entry_id in valid_id:
yield self.get_entry(entry_id)
Expand Down
45 changes: 40 additions & 5 deletions forte/data/ontology/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import numpy as np

from forte.common.exception import IncompleteEntryError
from forte.data.base_pack import PackType
from forte.data.ontology.core import Entry, BaseLink, BaseGroup, MultiEntry
from forte.data.span import Span
Expand Down Expand Up @@ -282,15 +281,51 @@ def __init__(
@property
def parent(self) -> Tuple[int, int]:
if self._parent is None:
raise IncompleteEntryError("Parent is not set for this link.")
raise ValueError("Parent is not set for this link.")
return self._parent

@property
def child(self) -> Tuple[int, int]:
if self._child is None:
raise IncompleteEntryError("Child is not set for this link.")
raise ValueError("Child is not set for this link.")
return self._child

def parent_id(self) -> int:
    """
    Return the `tid` of the parent entry.

    The value is the second element of the tuple exposed by ``self.parent``.

    Returns:
        The `tid` of the parent entry.

    Raises:
        ValueError: Propagated from ``self.parent`` when the parent of this
            link has not been set.
    """
    return self.parent[1]

def child_id(self) -> int:
    """
    Return the `tid` of the child entry.

    The value is the second element of the tuple exposed by ``self.child``.

    Returns:
        The `tid` of the child entry.

    Raises:
        ValueError: Propagated from ``self.child`` when the child of this
            link has not been set.
    """
    return self.child[1]

def parent_pack_id(self) -> int:
    """
    Return the `pack_id` of the parent pack.

    The first element of the stored parent tuple is used as an index into
    ``self.pack.packs``; the `pack_id` of the pack found there is returned.

    Returns:
        The `pack_id` of the parent pack.

    Raises:
        ValueError: If the parent of this link has not been set.
    """
    if self._parent is None:
        raise ValueError("Parent is not set for this link.")
    return self.pack.packs[self._parent[0]].pack_id

def child_pack_id(self) -> int:
    """
    Return the `pack_id` of the child pack.

    The first element of the stored child tuple is used as an index into
    ``self.pack.packs``; the `pack_id` of the pack found there is returned.

    Returns:
        The `pack_id` of the child pack.

    Raises:
        ValueError: If the child of this link has not been set.
    """
    if self._child is None:
        raise ValueError("Child is not set for this link.")
    return self.pack.packs[self._child[0]].pack_id

def set_parent(self, parent: Entry):
r"""This will set the `parent` of the current instance with given Entry.
The parent is saved internally as a tuple: ``pack index`` and
Expand Down Expand Up @@ -331,7 +366,7 @@ def get_parent(self) -> Entry:
An instance of :class:`Entry` that is the parent of the link.
"""
if self._parent is None:
raise IncompleteEntryError("The parent of this link is not set.")
raise ValueError("The parent of this link is not set.")

pack_idx, parent_tid = self._parent
return self.pack.get_subentry(pack_idx, parent_tid)
Expand All @@ -343,7 +378,7 @@ def get_child(self) -> Entry:
An instance of :class:`Entry` that is the child of the link.
"""
if self._child is None:
raise IncompleteEntryError("The parent of this link is not set.")
raise ValueError("The parent of this link is not set.")

pack_idx, child_tid = self._child
return self.pack.get_subentry(pack_idx, child_tid)
Expand Down
26 changes: 19 additions & 7 deletions forte/data/readers/deserialize_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from abc import ABC, abstractmethod

from typing import Iterator, List, Any, Union
from typing import Iterator, List, Any, Union, Optional

from forte.common.exception import ProcessExecutionException
from forte.data.data_pack import DataPack
Expand Down Expand Up @@ -126,14 +127,21 @@ def _parse_pack(self, multi_pack_str: str) -> Iterator[MultiPack]:

for pid in m_pack.pack_ids():
p_content = self._get_pack_content(pid)
if p_content is None:
logging.warning(
"Cannot locate the data pack with pid %d "
"for multi pack %d", pid, m_pack.pack_id)
break
pack: DataPack
if isinstance(p_content, str):
pack = DataPack.deserialize(p_content)
else:
pack = p_content
# Only in deserialization we can do this.
m_pack.packs.append(pack)
yield m_pack
else:
# Reached only when the loop did not break, i.e. every data pack was located; otherwise no multi pack is yielded.
yield m_pack

@abstractmethod
def _get_multipack_content(self, *args: Any, **kwargs: Any
Expand All @@ -148,7 +156,7 @@ def _get_multipack_content(self, *args: Any, **kwargs: Any
raise NotImplementedError

@abstractmethod
def _get_pack_content(self, pack_id: int) -> Union[str, DataPack]:
def _get_pack_content(self, pack_id: int) -> Union[None, str, DataPack]:
"""
Implementation of this method should be responsible for returning the
raw string of the data pack from the pack id.
Expand Down Expand Up @@ -182,10 +190,14 @@ def _get_multipack_content(self) -> Iterator[str]: # type: ignore
self.configs.multi_pack_dir, f)) as m_data:
yield m_data.read()

def _get_pack_content(self, pack_id: int) -> str:
with open(os.path.join(
self.configs.data_pack_dir, f'{pack_id}.json')) as pack_data:
return pack_data.read()
def _get_pack_content(self, pack_id: int) -> Optional[str]:
    """Read the serialized data pack stored for ``pack_id``.

    The pack is looked up as ``<pack_id>.json`` inside the directory given
    by ``self.configs.data_pack_dir``.

    Args:
        pack_id: The id of the data pack to load.

    Returns:
        The full text content of the pack file, or ``None`` when no file
        exists at the expected path.
    """
    pack_path = os.path.join(
        self.configs.data_pack_dir, f'{pack_id}.json')

    # A missing file is reported to the caller as ``None`` instead of
    # raising.
    if not os.path.exists(pack_path):
        return None

    with open(pack_path) as pack_data:
        return pack_data.read()

@classmethod
def default_configs(cls):
Expand Down
6 changes: 3 additions & 3 deletions forte/data/readers/stave_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import sqlite3
from typing import Iterator, Dict
from typing import Iterator, Dict, Optional

from forte.common import Resources, ProcessorConfigError
from forte.common.configuration import Config
Expand Down Expand Up @@ -85,8 +85,8 @@ def _get_multipack_content(self) -> Iterator[str]: # type: ignore
f'SELECT textPack FROM {self.configs.multipack_table}'):
yield value[0]

def _get_pack_content(self, pack_id: int) -> DataPack:
return self.data_packs[pack_id]
def _get_pack_content(self, pack_id: int) -> Optional[DataPack]:
return self.data_packs.get(pack_id, None)

@classmethod
def default_configs(cls):
Expand Down

0 comments on commit 258467a

Please sign in to comment.