diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index dec7bd07d..dd2648524 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -280,6 +280,7 @@ You may find these examples useful as references: - `A dataset with a remote index `_ - `A dataset with extra dependencies `_ - `A dataset which has multitracks `_ + - `A dataset which has multiple annotators `_ For many more examples, see the `datasets folder `_. diff --git a/mirdata/annotations.py b/mirdata/annotations.py index d053540b8..b77abeaa7 100644 --- a/mirdata/annotations.py +++ b/mirdata/annotations.py @@ -79,6 +79,26 @@ def __repr__(self): return repr_str +class MultiAnnotator(object): + """Multiple annotator class. + This class should be used for datasets with multiple annotators (e.g. multiple annotators per track). + + Attributes: + annotators (list): list with annotator ids + annotations (list): list of annotations (e.g. [beat_data1, beat_data2] each with + type BeatData or [chord_data1, chord_data2] each with type chord data) + + """ + + def __init__(self, annotators, annotations, dtype) -> None: + validate_array_like(annotators, list, str, none_allowed=True) + validate_array_like(annotations, list, dtype, none_allowed=True) + validate_lengths_equal([annotators, annotations]) + + self.annotators = annotators + self.annotations = annotations + + class BeatData(Annotation): """BeatData class @@ -376,8 +396,9 @@ def to_sparse_index( ] ) - return np.array(index), convert_amplitude_units( - voicing, self.voicing_unit, amplitude_unit + return ( + np.array(index), + convert_amplitude_units(voicing, self.voicing_unit, amplitude_unit), ) def to_matrix( @@ -708,8 +729,9 @@ def to_sparse_index( if t != -1 and f != -1 ] ) - return np.array(index), convert_amplitude_units( - confidence_out, conf_unit, amplitude_unit + return ( + np.array(index), + convert_amplitude_units(confidence_out, conf_unit, amplitude_unit), ) def to_matrix( @@ -1383,7 +1405,9 @@ def 
closest_index(input_array, target_array): return indexes -def validate_array_like(array_like, expected_type, expected_dtype, none_allowed=False): +def validate_array_like( + array_like, expected_type, expected_dtype, check_child=False, none_allowed=False +): """Validate that array-like object is well formed If array_like is None, validation passes automatically. @@ -1392,11 +1416,12 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed= array_like (array-like): object to validate expected_type (type): expected type, either list or np.ndarray expected_dtype (type): expected dtype + check_child (bool): if True, checks if all elements of array are children of expected_dtype none_allowed (bool): if True, allows array to be None Raises: TypeError: if type/dtype does not match expected_type/expected_dtype - ValueError: if array + ValueError: if array is empty but it shouldn't be """ if array_like is None: @@ -1416,7 +1441,9 @@ def validate_array_like(array_like, expected_type, expected_dtype, none_allowed= ) if expected_type == list and not all( - isinstance(n, expected_dtype) for n in array_like + isinstance(n, expected_dtype) + for n in array_like + if not ((n is None) and none_allowed) ): raise TypeError(f"List elements should all have type {expected_dtype}") diff --git a/mirdata/datasets/salami.py b/mirdata/datasets/salami.py index 71ee9868e..b0b607951 100644 --- a/mirdata/datasets/salami.py +++ b/mirdata/datasets/salami.py @@ -20,6 +20,7 @@ import librosa import numpy as np +import logging from mirdata import annotations, core, download_utils, io, jams_utils @@ -92,6 +93,8 @@ class Track(core.Track): sections_annotator_1_lowercase (SectionData): annotations in hierarchy level 1 from annotator 1 sections_annotator_2_uppercase (SectionData): annotations in hierarchy level 0 from annotator 2 sections_annotator_2_lowercase (SectionData): annotations in hierarchy level 1 from annotator 2 + sections_uppercase (annotations.MultiAnnotator): 
annotations in hierarchy level 0 + sections_lowercase (annotations.MultiAnnotator): annotations in hierarchy level 1 """ def __init__( @@ -157,20 +160,68 @@ def broad_genre(self): def genre(self): return self._track_metadata.get("genre") + @core.cached_property + def sections_uppercase(self) -> Optional[annotations.MultiAnnotator]: + return annotations.MultiAnnotator( + [ + self._track_metadata.get("annotator_1_id"), + self._track_metadata.get("annotator_2_id"), + ], + [ + load_sections(self.sections_annotator1_uppercase_path), + load_sections(self.sections_annotator2_uppercase_path), + ], + annotations.SectionData, + ) + + @core.cached_property + def sections_lowercase(self) -> Optional[annotations.MultiAnnotator]: + return annotations.MultiAnnotator( + [ + self._track_metadata.get("annotator_1_id"), + self._track_metadata.get("annotator_2_id"), + ], + [ + load_sections(self.sections_annotator1_lowercase_path), + load_sections(self.sections_annotator2_lowercase_path), + ], + annotations.SectionData, + ) + @core.cached_property def sections_annotator_1_uppercase(self) -> Optional[annotations.SectionData]: + logging.warning( + "Deprecation warning: sections_annotator_1_uppercase is deprecated starting " + "in version 0.3.4b3 and will be removed in a future version. " + "Use sections_uppercase in the future." + ) return load_sections(self.sections_annotator1_uppercase_path) @core.cached_property def sections_annotator_1_lowercase(self) -> Optional[annotations.SectionData]: + logging.warning( + "Deprecation warning: sections_annotator_1_lowercase is deprecated starting " + "in version 0.3.4b3 and will be removed in a future version. " + "Use sections_lowercase in the future." 
+ ) return load_sections(self.sections_annotator1_lowercase_path) @core.cached_property def sections_annotator_2_uppercase(self) -> Optional[annotations.SectionData]: + logging.warning( + "Deprecation warning: sections_annotator_2_uppercase is deprecated starting " + "in version 0.3.4b3 and will be removed in a future version. " + "Use sections_uppercase in the future." + ) return load_sections(self.sections_annotator2_uppercase_path) @core.cached_property def sections_annotator_2_lowercase(self) -> Optional[annotations.SectionData]: + logging.warning( + "Deprecation warning: sections_annotator_2_lowercase is deprecated starting " + "in version 0.3.4b3 and will be removed in a future version. " + "Use sections_lowercase in the future." + ) return load_sections(self.sections_annotator2_lowercase_path) @property @@ -196,17 +247,17 @@ def to_jams(self): multi_section_data=[ ( [ - (self.sections_annotator_1_uppercase, 0), - (self.sections_annotator_1_lowercase, 1), + (self.sections_uppercase.annotations[0], 0), + (self.sections_lowercase.annotations[0], 1), ], - "annotator_1", + self.sections_lowercase.annotators[0], ), ( [ - (self.sections_annotator_2_uppercase, 0), - (self.sections_annotator_2_lowercase, 1), + (self.sections_uppercase.annotations[1], 0), + (self.sections_lowercase.annotations[1], 1), ], - "annotator_2", + self.sections_lowercase.annotators[1], ), ], metadata=self._track_metadata, @@ -229,7 +280,7 @@ def load_audio(fpath: str) -> Tuple[np.ndarray, float]: @io.coerce_to_string_io -def load_sections(fhandle: TextIO) -> annotations.SectionData: +def load_sections(fhandle: TextIO) -> Optional[annotations.SectionData]: """Load salami sections data from a file Args: @@ -247,6 +298,8 @@ def load_sections(fhandle: TextIO) -> annotations.SectionData: secs.append(line[1]) times = np.array(times) # type: ignore secs = np.array(secs) # type: ignore + if len(times) == 0: + return None # remove sections with length == 0 times_revised = np.delete(times, 
np.where(np.diff(times) == 0)) diff --git a/tests/datasets/test_salami.py b/tests/datasets/test_salami.py index d6e0ab5cc..543fc8a2e 100644 --- a/tests/datasets/test_salami.py +++ b/tests/datasets/test_salami.py @@ -35,6 +35,8 @@ def test_track(): } expected_property_types = { + "sections_uppercase": annotations.MultiAnnotator, + "sections_lowercase": annotations.MultiAnnotator, "sections_annotator_1_uppercase": annotations.SectionData, "sections_annotator_1_lowercase": annotations.SectionData, "sections_annotator_2_uppercase": annotations.SectionData, @@ -81,6 +83,12 @@ def test_track(): } # test that cached properties don't fail and have the expected type + assert type(track.sections_uppercase) is annotations.MultiAnnotator + assert type(track.sections_lowercase) is annotations.MultiAnnotator + assert track.sections_uppercase.annotations[1] is None + assert track.sections_lowercase.annotations[1] is None + + # test deprecated attributes for coverage assert type(track.sections_annotator_1_uppercase) is annotations.SectionData assert type(track.sections_annotator_1_lowercase) is annotations.SectionData assert track.sections_annotator_2_uppercase is None @@ -104,6 +112,12 @@ def test_track(): } # test that cached properties don't fail and have the expected type + assert track.sections_uppercase.annotators[0] is None + assert track.sections_lowercase.annotations[0] is None + assert type(track.sections_uppercase) is annotations.MultiAnnotator + assert type(track.sections_lowercase) is annotations.MultiAnnotator + + # test deprecated attributes for coverage assert track.sections_annotator_1_uppercase is None assert track.sections_annotator_1_lowercase is None assert type(track.sections_annotator_2_uppercase) is annotations.SectionData @@ -201,6 +215,15 @@ def test_load_sections(): section_data_none = salami.load_sections(None) assert section_data_none is None + # load an empty file + sections_path = ( + "tests/resources/mir_datasets/salami/" + + 
"salami-data-public-hierarchy-corrections/annotations/2/parsed/textfile1_uppercase_empty.txt" + ) + + section_data = salami.load_sections(sections_path) + assert section_data is None + def test_load_metadata(): data_home = "tests/resources/mir_datasets/salami" diff --git a/tests/resources/mir_datasets/salami/salami-data-public-hierarchy-corrections/annotations/2/parsed/textfile1_uppercase_empty.txt b/tests/resources/mir_datasets/salami/salami-data-public-hierarchy-corrections/annotations/2/parsed/textfile1_uppercase_empty.txt new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 3883a73c4..051054c63 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -27,6 +27,35 @@ def __init__(self): ) +def test_multiannotator(): + # test good data + annotators = ["annotator_1", "annotator_2"] + labels_1 = ["Vocals", "Guitar"] + labels_2 = ["Vocals", "Drums"] + intervals_1 = np.array([[0.0, 0.1], [0.5, 1.5]]) + intervals_2 = np.array([[0.0, 1.0], [0.5, 1.0]]) + multi_annot = [ + annotations.EventData(intervals_1, "s", labels_1, "open"), + annotations.EventData(intervals_2, "s", labels_2, "open"), + ] + events = annotations.MultiAnnotator(annotators, multi_annot, annotations.EventData) + + assert events.annotations[0].events == labels_1 + assert events.annotators[1] == "annotator_2" + assert np.allclose(events.annotations[1].intervals, intervals_2) + + # test bad data + bad_labels = ["Is a", "Number", 5] + pytest.raises(TypeError, annotations.MultiAnnotator, annotators, bad_labels) + pytest.raises(TypeError, annotations.MultiAnnotator, [0, 1], multi_annot) + pytest.raises( + TypeError, + annotations.MultiAnnotator, + annotators, + [["bad", "format"], ["indeed"]], + ) + + def test_beat_data(): times = np.array([1.0, 2.0]) positions = np.array([3, 4])