Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultiAnnotator class #515

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ You may find these examples useful as references:
- `A dataset with a remote index <https://github.com/mir-dataset-loaders/mirdata/blob/master/mirdata/datasets/acousticbrainz_genre.py>`_
- `A dataset with extra dependencies <https://github.com/mir-dataset-loaders/mirdata/blob/master/mirdata/datasets/dali.py>`_
- `A dataset which has multitracks <https://github.com/mir-dataset-loaders/mirdata/blob/master/mirdata/datasets/phenicx_anechoic.py>`_
- `A dataset which has multiple annotators <https://github.com/mir-dataset-loaders/mirdata/blob/master/mirdata/datasets/salami.py>`_

For many more examples, see the `datasets folder <https://github.com/mir-dataset-loaders/mirdata/tree/master/mirdata/datasets>`_.

Expand Down
41 changes: 34 additions & 7 deletions mirdata/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def __repr__(self):
return repr_str


class MultiAnnotator(object):
"""Multiple annotator class.
This class should be used for datasets with multiple annotators (e.g. multiple annotators per track).

Attributes:
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
annotators (list): list with annotator ids
annotations (list): list of annotations (e.g. [beat_data1, beat_data2] each with
type BeatData or [chord_data1, chord_data2] each with type chord data

"""

def __init__(self, annotators, annotations, dtype) -> None:
validate_array_like(annotators, list, str, none_allowed=True)
validate_array_like(annotations, list, dtype, none_allowed=True)
validate_lengths_equal([annotators, annotations])

self.annotators = annotators
self.annotations = annotations
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved


class BeatData(Annotation):
"""BeatData class

Expand Down Expand Up @@ -376,8 +396,9 @@ def to_sparse_index(
]
)

return np.array(index), convert_amplitude_units(
voicing, self.voicing_unit, amplitude_unit
return (
np.array(index),
convert_amplitude_units(voicing, self.voicing_unit, amplitude_unit),
)

def to_matrix(
Expand Down Expand Up @@ -708,8 +729,9 @@ def to_sparse_index(
if t != -1 and f != -1
]
)
return np.array(index), convert_amplitude_units(
confidence_out, conf_unit, amplitude_unit
return (
np.array(index),
convert_amplitude_units(confidence_out, conf_unit, amplitude_unit),
)

def to_matrix(
Expand Down Expand Up @@ -1383,7 +1405,9 @@ def closest_index(input_array, target_array):
return indexes


def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=False):
def validate_array_like(
array_like, expected_type, expected_dtype, check_child=False, none_allowed=False
):
"""Validate that array-like object is well formed

If array_like is None, validation passes automatically.
Expand All @@ -1392,11 +1416,12 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=
array_like (array-like): object to validate
expected_type (type): expected type, either list or np.ndarray
expected_dtype (type): expected dtype
check_child (bool): if True, checks if all elements of array are children of expected_dtype
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if there are multiple children with different expected dtypes (e.g a beat annotation with ints for the beat positions and floats for the time stamps?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we're only checking if the list is composed of BeatData, ChordData, etc. The checks within the annotation type are done for each annotation type independently

none_allowed (bool): if True, allows array to be None

Raises:
TypeError: if type/dtype does not match expected_type/expected_dtype
ValueError: if array
ValueError: if array is empty but it shouldn't be

"""
if array_like is None:
Expand All @@ -1416,7 +1441,9 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=
)

if expected_type == list and not all(
isinstance(n, expected_dtype) for n in array_like
isinstance(n, expected_dtype)
for n in array_like
if not ((n is None) and none_allowed)
):
raise TypeError(f"List elements should all have type {expected_dtype}")

Expand Down
67 changes: 60 additions & 7 deletions mirdata/datasets/salami.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import librosa
import numpy as np
import logging

from mirdata import annotations, core, download_utils, io, jams_utils

Expand Down Expand Up @@ -92,6 +93,8 @@ class Track(core.Track):
sections_annotator_1_lowercase (SectionData): annotations in hierarchy level 1 from annotator 1
sections_annotator_2_uppercase (SectionData): annotations in hierarchy level 0 from annotator 2
sections_annotator_2_lowercase (SectionData): annotations in hierarchy level 1 from annotator 2
sections_uppercase (annotations.MultiAnnotator): annotations in hierarchy level 0
sections_lowercase (annotations.MultiAnnotator): annotations in hierarchy level 1
"""

def __init__(
Expand Down Expand Up @@ -157,20 +160,68 @@ def broad_genre(self):
def genre(self):
return self._track_metadata.get("genre")

@core.cached_property
def sections_uppercase(self) -> Optional[annotations.MultiAnnotator]:
magdalenafuentes marked this conversation as resolved.
Show resolved Hide resolved
return annotations.MultiAnnotator(
[
self._track_metadata.get("annotator_1_id"),
self._track_metadata.get("annotator_2_id"),
],
[
load_sections(self.sections_annotator1_uppercase_path),
load_sections(self.sections_annotator2_uppercase_path),
],
annotations.SectionData,
)

@core.cached_property
def sections_lowercase(self) -> Optional[annotations.MultiAnnotator]:
return annotations.MultiAnnotator(
[
self._track_metadata.get("annotator_1_id"),
self._track_metadata.get("annotator_2_id"),
],
[
load_sections(self.sections_annotator1_lowercase_path),
load_sections(self.sections_annotator2_lowercase_path),
],
annotations.SectionData,
)

@core.cached_property
def sections_annotator_1_uppercase(self) -> Optional[annotations.SectionData]:
logging.warning(
"Deprecation warning: sections_anntotator_1_uppercase is deprecated starting "
"in version 0.3.4b3 and will be removed in a future version. "
"Use sections_uppercase in the future."
)
return load_sections(self.sections_annotator1_uppercase_path)

@core.cached_property
def sections_annotator_1_lowercase(self) -> Optional[annotations.SectionData]:
logging.warning(
"Deprecation warning: sections_annotator_1_lowercase is deprecated starting "
"in version 0.3.4b3 and will be removed in a future version. "
"Use sections_lowercase in the future."
)
return load_sections(self.sections_annotator1_lowercase_path)

@core.cached_property
def sections_annotator_2_uppercase(self) -> Optional[annotations.SectionData]:
logging.warning(
"Deprecation warning: sections_anntotator_2_uppercase is deprecated starting "
"in version 0.3.4b3 and will be removed in a future version. "
"Use sections_uppercase in the future."
)
return load_sections(self.sections_annotator2_uppercase_path)

@core.cached_property
def sections_annotator_2_lowercase(self) -> Optional[annotations.SectionData]:
logging.warning(
"Deprecation warning: sections_annotator_2_lowercase is deprecated starting "
"in version 0.3.4b3 and will be removed in a future version. "
"Use sections_lowercase in the future."
)
return load_sections(self.sections_annotator2_lowercase_path)

@property
Expand All @@ -196,17 +247,17 @@ def to_jams(self):
multi_section_data=[
(
[
(self.sections_annotator_1_uppercase, 0),
(self.sections_annotator_1_lowercase, 1),
(self.sections_uppercase.annotations[0], 0),
(self.sections_lowercase.annotations[0], 1),
],
"annotator_1",
self.sections_lowercase.annotators[0],
),
(
[
(self.sections_annotator_2_uppercase, 0),
(self.sections_annotator_2_lowercase, 1),
(self.sections_uppercase.annotations[1], 0),
(self.sections_lowercase.annotations[1], 1),
],
"annotator_2",
self.sections_lowercase.annotators[1],
),
],
metadata=self._track_metadata,
Expand All @@ -229,7 +280,7 @@ def load_audio(fpath: str) -> Tuple[np.ndarray, float]:


@io.coerce_to_string_io
def load_sections(fhandle: TextIO) -> annotations.SectionData:
def load_sections(fhandle: TextIO) -> Optional[annotations.SectionData]:
"""Load salami sections data from a file

Args:
Expand All @@ -247,6 +298,8 @@ def load_sections(fhandle: TextIO) -> annotations.SectionData:
secs.append(line[1])
times = np.array(times) # type: ignore
secs = np.array(secs) # type: ignore
if len(times) == 0:
return None

# remove sections with length == 0
times_revised = np.delete(times, np.where(np.diff(times) == 0))
Expand Down
23 changes: 23 additions & 0 deletions tests/datasets/test_salami.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def test_track():
}

expected_property_types = {
"sections_uppercase": annotations.MultiAnnotator,
"sections_lowercase": annotations.MultiAnnotator,
"sections_annotator_1_uppercase": annotations.SectionData,
"sections_annotator_1_lowercase": annotations.SectionData,
"sections_annotator_2_uppercase": annotations.SectionData,
Expand Down Expand Up @@ -81,6 +83,12 @@ def test_track():
}

# test that cached properties don't fail and have the expected type
assert type(track.sections_uppercase) is annotations.MultiAnnotator
assert type(track.sections_lowercase) is annotations.MultiAnnotator
assert track.sections_uppercase.annotations[1] is None
assert track.sections_lowercase.annotations[1] is None

# test deprecated attributes for coverage
assert type(track.sections_annotator_1_uppercase) is annotations.SectionData
assert type(track.sections_annotator_1_lowercase) is annotations.SectionData
assert track.sections_annotator_2_uppercase is None
Expand All @@ -104,6 +112,12 @@ def test_track():
}

# test that cached properties don't fail and have the expected type
assert track.sections_uppercase.annotators[0] is None
assert track.sections_lowercase.annotations[0] is None
assert type(track.sections_uppercase) is annotations.MultiAnnotator
assert type(track.sections_lowercase) is annotations.MultiAnnotator

# test deprecated attributes for coverage
assert track.sections_annotator_1_uppercase is None
assert track.sections_annotator_1_lowercase is None
assert type(track.sections_annotator_2_uppercase) is annotations.SectionData
Expand Down Expand Up @@ -201,6 +215,15 @@ def test_load_sections():
section_data_none = salami.load_sections(None)
assert section_data_none is None

# load an empty file
sections_path = (
"tests/resources/mir_datasets/salami/"
+ "salami-data-public-hierarchy-corrections/annotations/2/parsed/textfile1_uppercase_empty.txt"
)

section_data = salami.load_sections(sections_path)
assert section_data is None


def test_load_metadata():
data_home = "tests/resources/mir_datasets/salami"
Expand Down
29 changes: 29 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ def __init__(self):
)


def test_multiannotator():
# test good data
annotators = ["annotator_1", "annotator_2"]
labels_1 = ["Vocals", "Guitar"]
labels_2 = ["Vocals", "Drums"]
intervals_1 = np.array([[0.0, 0.1], [0.5, 1.5]])
intervals_2 = np.array([[0.0, 1.0], [0.5, 1.0]])
multi_annot = [
annotations.EventData(intervals_1, "s", labels_1, "open"),
annotations.EventData(intervals_2, "s", labels_2, "open"),
]
events = annotations.MultiAnnotator(annotators, multi_annot, annotations.EventData)

assert events.annotations[0].events == labels_1
assert events.annotators[1] == "annotator_2"
assert np.allclose(events.annotations[1].intervals, intervals_2)

# test bad data
bad_labels = ["Is a", "Number", 5]
pytest.raises(TypeError, annotations.MultiAnnotator, annotators, bad_labels)
pytest.raises(TypeError, annotations.MultiAnnotator, [0, 1], multi_annot)
pytest.raises(
TypeError,
annotations.MultiAnnotator,
annotators,
[["bad", "format"], ["indeed"]],
)


def test_beat_data():
times = np.array([1.0, 2.0])
positions = np.array([3, 4])
Expand Down