diff --git a/.github/workflows/audiocraft_linter.yml b/.github/workflows/audiocraft_linter.yml index 812b2aec..60479fa6 100644 --- a/.github/workflows/audiocraft_linter.yml +++ b/.github/workflows/audiocraft_linter.yml @@ -3,7 +3,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, audiocraft_pub_main ] jobs: run_linter: diff --git a/.github/workflows/audiocraft_tests.yml b/.github/workflows/audiocraft_tests.yml index 829b37aa..e3476361 100644 --- a/.github/workflows/audiocraft_tests.yml +++ b/.github/workflows/audiocraft_tests.yml @@ -3,7 +3,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + branches: [ main, audiocraft_pub_main ] jobs: run_tests: diff --git a/assets/CJ_Beatbox_Loop_05_90.wav b/assets/CJ_Beatbox_Loop_05_90.wav new file mode 100644 index 00000000..2e2d9c5b Binary files /dev/null and b/assets/CJ_Beatbox_Loop_05_90.wav differ diff --git a/assets/chord_to_index_mapping.pkl b/assets/chord_to_index_mapping.pkl new file mode 100644 index 00000000..402669df Binary files /dev/null and b/assets/chord_to_index_mapping.pkl differ diff --git a/assets/salience_1.th b/assets/salience_1.th new file mode 100644 index 00000000..dc0c9c56 Binary files /dev/null and b/assets/salience_1.th differ diff --git a/assets/salience_1.wav b/assets/salience_1.wav new file mode 100644 index 00000000..205e52e2 Binary files /dev/null and b/assets/salience_1.wav differ diff --git a/assets/salience_2.th b/assets/salience_2.th new file mode 100644 index 00000000..8cc0dc11 Binary files /dev/null and b/assets/salience_2.th differ diff --git a/assets/salience_2.wav b/assets/salience_2.wav new file mode 100644 index 00000000..af77e1aa Binary files /dev/null and b/assets/salience_2.wav differ diff --git a/assets/sep_drums_1.mp3 b/assets/sep_drums_1.mp3 new file mode 100644 index 00000000..6aaf0da5 Binary files /dev/null and b/assets/sep_drums_1.mp3 differ diff --git a/audiocraft/data/__init__.py b/audiocraft/data/__init__.py index 2906ff12..fdd35f2b 100644 --- a/audiocraft/data/__init__.py +++ b/audiocraft/data/__init__.py @@ -7,4 +7,4 @@ or also including some metadata.""" # flake8: noqa -from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset +from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset, jasco_dataset diff --git a/audiocraft/data/jasco_dataset.py b/audiocraft/data/jasco_dataset.py new file mode 100644 index 00000000..933c7291 --- /dev/null +++ b/audiocraft/data/jasco_dataset.py @@ -0,0 +1,312 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import bisect +import pickle +import math +import os +import torch +import typing as tp +from pathlib import Path +from dataclasses import dataclass, fields +from ..utils.utils import construct_frame_chords +from .music_dataset import MusicDataset, MusicInfo +from .audio_dataset import load_audio_meta +from ..modules.conditioners import (ConditioningAttributes, SymbolicCondition) +import librosa +import numpy as np + + +@dataclass +class JascoInfo(MusicInfo): + """ + A data class extending MusicInfo for JASCO. The following attributes are added: + Attributes: + frame_chords (Optional[list]): A list of chords associated with frames in the music piece. 
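+        chords (Optional[SymbolicCondition]): Per-frame chord indices wrapped in a SymbolicCondition (frame_chords).
+        melody (Optional[SymbolicCondition]): Melody salience matrix wrapped in a SymbolicCondition.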
+ """ + chords: tp.Optional[SymbolicCondition] = None + melody: tp.Optional[SymbolicCondition] = None + + def to_condition_attributes(self) -> ConditioningAttributes: + out = ConditioningAttributes() + for _field in fields(self): + key, value = _field.name, getattr(self, _field.name) + if key == 'self_wav': + out.wav[key] = value + elif key in {'chords', 'melody'}: + out.symbolic[key] = value + elif key == 'joint_embed': + for embed_attribute, embed_cond in value.items(): + out.joint_embed[embed_attribute] = embed_cond + else: + if isinstance(value, list): + value = ' '.join(value) + out.text[key] = value + return out + + +class MelodyData: + + SALIENCE_MODEL_EXPECTED_SAMPLE_RATE = 22050 + SALIENCE_MODEL_EXPECTED_HOP_SIZE = 256 + + def __init__(self, + latent_fr: int, + segment_duration: float, + melody_fr: int = 86, + melody_salience_dim: int = 53, + chroma_root: tp.Optional[str] = None, + override_cache: bool = False, + do_argmax: bool = True): + """Module to load salience matrix for a given info. + + Args: + latent_fr (int): latent frame rate to match (interpolates model frame rate accordingly). + segment_duration (float): expected segment duration. + melody_fr (int, optional): extracted salience frame rate. Defaults to 86. + melody_salience_dim (int, optional): salience dim. Defaults to 53. + chroma_root (str, optional): path to root containing salience cache. Defaults to None. + override_cache (bool, optional): rewrite cache. Defaults to False. + do_argmax (bool, optional): argmax the melody matrix. Defaults to True. + """ + + self.segment_duration = segment_duration + self.melody_fr = melody_fr + self.latent_fr = latent_fr + self.melody_salience_dim = melody_salience_dim + self.do_argmax = do_argmax + self.tgt_chunk_len = int(latent_fr * segment_duration) + + self.null_op = False + if chroma_root is None: + self.null_op = True + elif not os.path.exists(f"{chroma_root}/cache.pkl") or override_cache: + self.tracks = [] + for file in librosa.util.find_files(chroma_root, ext='txt'): + with open(file, 'r') as f: + lines = f.readlines() + for line in lines: + self.tracks.append(line.strip()) + + # go over tracks and add the corresponding saliency file to self.saliency_files + self.saliency_files = [] + for track in self.tracks: + # saliency file name + salience_file = f"{chroma_root}/{track.split('/')[-1].split('.')[0]}_multif0_salience.npz" + assert os.path.exists(salience_file), f"File {salience_file} does not exist" + self.saliency_files.append(salience_file) + + self.trk2idx = {trk.split('/')[-1].split('.')[0]: i for i, trk in enumerate(self.tracks)} + torch.save({'tracks': self.tracks, + 'saliency_files': self.saliency_files, + 'trk2idx': self.trk2idx}, f"{chroma_root}/cache.pkl") + else: + tmp = torch.load(f"{chroma_root}/cache.pkl") + self.tracks = tmp['tracks'] + self.saliency_files = tmp['saliency_files'] + self.trk2idx = tmp['trk2idx'] + self.model_frame_rate = int(self.SALIENCE_MODEL_EXPECTED_SAMPLE_RATE / self.SALIENCE_MODEL_EXPECTED_HOP_SIZE) + + def load_saliency_from_saliency_dict(self, + saliency_dict: tp.Dict[str, tp.Any], + offset: float) -> torch.Tensor: + """ + construct the salience matrix and perform linear interpolation w.r.t the temporal axis to match the expected + frame rate. 
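+
+        Args:
+            saliency_dict (dict): Mapping with 'salience' (freq bins x frames), 'times' and 'freqs' entries,
+                as stored in the pre-computed *_multif0_salience.npz files.
+            offset (float): Segment start time (seek time) in seconds.
+        Returns:
+            torch.Tensor: Salience matrix of shape [num freq bins, latent_fr * segment_duration], binarized to
+                the most salient bin per frame when do_argmax is set.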
+ """ + # get saliency map for the segment + saliency_dict_ = {} + l, r = int(offset * self.model_frame_rate), int((offset + self.segment_duration) * self.model_frame_rate) + saliency_dict_['salience'] = saliency_dict['salience'][:, l: r].T + saliency_dict_['times'] = saliency_dict['times'][l: r] - offset + saliency_dict_['freqs'] = saliency_dict['freqs'] + + saliency_dict_['salience'] = torch.Tensor(saliency_dict_['salience']).float().permute(1, 0) # C, T + if saliency_dict_['salience'].shape[-1] <= int(self.model_frame_rate) / self.latent_fr: # empty chroma + saliency_dict_['salience'] = torch.zeros((saliency_dict_['salience'].shape[0], self.tgt_chunk_len)) + else: + salience = torch.nn.functional.interpolate(saliency_dict_['salience'].unsqueeze(0), + scale_factor=self.latent_fr/int(self.model_frame_rate), + mode='linear').squeeze(0) + if salience.shape[-1] < self.tgt_chunk_len: + salience = torch.nn.functional.pad(salience, + (0, self.tgt_chunk_len - salience.shape[-1]), + mode='constant', + value=0) + elif salience.shape[-1] > self.tgt_chunk_len: + salience = salience[..., :self.tgt_chunk_len] + saliency_dict_['salience'] = salience + + salience = saliency_dict_['salience'] + if self.do_argmax: + binary_mask = torch.zeros_like(salience) + binary_mask[torch.argmax(salience, dim=0), torch.arange(salience.shape[-1])] = 1 + binary_mask *= (salience != 0).float() + salience = binary_mask + return salience + + def get_null_salience(self) -> torch.Tensor: + return torch.zeros((self.melody_salience_dim, self.tgt_chunk_len)) + + def __call__(self, x: MusicInfo) -> torch.Tensor: + """Reads salience matrix from memory, shifted by seek time + + Args: + x (MusicInfo): Music info of a single sample + + Returns: + torch.Tensor: salience matrix matching the target info + """ + fname: str = x.meta.path.split("/")[-1].split(".")[0] if x.meta.path is not None else "" + if x.meta.path is None or x.meta.path == "" or fname not in self.trk2idx: + salience = self.get_null_salience() + else: + assert fname in self.trk2idx, f"Track {fname} not found in the cache" + idx = self.trk2idx[fname] + saliency_dict = np.load(self.saliency_files[idx], allow_pickle=True) + salience = self.load_saliency_from_saliency_dict(saliency_dict, x.seek_time) + return salience + + +class JascoDataset(MusicDataset): + """JASCO dataset is a MusicDataset with jasco-related symbolic data (chords, melody). + + Args: + chords_card (int): The cardinality of the chords, default is 194. + compression_model_framerate (int): The framerate for the compression model, default is 50. + + See `audiocraft.data.info_audio_dataset.MusicDataset` for full initialization arguments. + """ + @classmethod + def from_meta(cls, root: tp.Union[str, Path], **kwargs): + """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file. + + Args: + root (str or Path): Path to root folder containing audio files. + kwargs: Additional keyword arguments for the AudioDataset. + """ + root = Path(root) + # a directory is given + if root.is_dir(): + if (root / 'data.jsonl').exists(): + meta_json = root / 'data.jsonl' + elif (root / 'data.jsonl.gz').exists(): + meta_json = root / 'data.jsonl.gz' + else: + raise ValueError("Don't know where to read metadata from in the dir. 
" + "Expecting either a data.jsonl or data.jsonl.gz file but none found.") + # jsonl file was specified + else: + assert root.exists() and root.suffix == '.jsonl', \ + "Either specified path not exist or it is not a jsonl format" + meta_json = root + root = root.parent + meta = load_audio_meta(meta_json) + kwargs['root'] = root + return cls(meta, **kwargs) + + def __init__(self, *args, + chords_card: int = 194, + compression_model_framerate: float = 50., + melody_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = {}, + **kwargs): + """Dataset class for text-to-music generation with temporal controls as in + (JASCO)[https://arxiv.org/pdf/2406.10970] + + Args: + chords_card (int, optional): Number of chord ebeddings. Defaults to 194. + compression_model_framerate (float, optional): Expected frame rate of the resulted latent. Defaults to 50. + melody_kwargs (tp.Optional[tp.Dict[str, tp.Any]], optional): See MelodyData class. Defaults to {}. + """ + root = kwargs.pop('root') + super().__init__(*args, **kwargs) + + chords_mapping_path = root / 'chord_to_index_mapping.pkl' + chords_path = root / 'chords_per_track.pkl' + self.mapping_dict = pickle.load(open(chords_mapping_path, "rb")) if \ + os.path.exists(chords_mapping_path) else None + + self.chords_per_track = pickle.load(open(chords_path, "rb")) if \ + os.path.exists(chords_path) else None + + self.compression_model_framerate = compression_model_framerate + self.null_chord_idx = chords_card + + self.melody_module = MelodyData(**melody_kwargs) # type: ignore + + def _get_relevant_sublist(self, chords, timestamp): + """ + Returns the sublist of chords within the specified timestamp and segment length. + + Args: + chords (list): A sorted list of tuples containing (time changed, chord). + timestamp (float): The timestamp at which to start the sublist. + + Returns: + list: A list of chords within the specified timestamp and segment length. 
+ """ + end_time = timestamp + self.segment_duration + + # Use binary search to find the starting index of the relevant sublist + start_index = bisect.bisect_left(chords, (timestamp,)) + + if start_index != 0: + prev_chord = chords[start_index - 1] + else: + prev_chord = (0.0, "N") + + relevant_chords = [] + + for time_changed, chord in chords[start_index:]: + if time_changed >= end_time: + break + relevant_chords.append((time_changed, chord)) + + return relevant_chords, prev_chord + + def _get_chords(self, music_info: MusicInfo, effective_segment_dur: float) -> torch.Tensor: + if self.chords_per_track is None: + # use null chord when there's no chords in dataset + seq_len = math.ceil(self.compression_model_framerate * effective_segment_dur) + return torch.ones(seq_len, dtype=int) * self.null_chord_idx # type: ignore + + fr = self.compression_model_framerate + + idx = music_info.meta.path.split("/")[-1].split(".")[0] + chords = self.chords_per_track[idx] + + min_timestamp = music_info.seek_time + + chords = [(item[1], item[0]) for item in chords] + chords, prev_chord = self._get_relevant_sublist( + chords, min_timestamp + ) + + iter_min_timestamp = int(min_timestamp * fr) + 1 + + frame_chords = construct_frame_chords( + iter_min_timestamp, chords, self.mapping_dict, prev_chord[1], # type: ignore + fr, self.segment_duration # type: ignore + ) + + return torch.tensor(frame_chords) + + def __getitem__(self, index): + wav, music_info = super().__getitem__(index) + assert not wav.isinfinite().any(), f"inf detected in wav file: {music_info}" + wav = wav.float() + + # downcast music info to jasco info + jasco_info = JascoInfo({k: v for k, v in music_info.__dict__.items()}) + + # get chords + effective_segment_dur = (wav.shape[-1] / self.sample_rate) if \ + self.segment_duration is None else self.segment_duration + frame_chords = self._get_chords(music_info, effective_segment_dur) + jasco_info.chords = SymbolicCondition(frame_chords=frame_chords) + + # get melody + jasco_info.melody = SymbolicCondition(melody=self.melody_module(music_info)) + return wav, jasco_info diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py index 3f485fed..b7032ea5 100644 --- a/audiocraft/models/__init__.py +++ b/audiocraft/models/__init__.py @@ -14,8 +14,10 @@ from .audiogen import AudioGen from .lm import LMModel from .lm_magnet import MagnetLMModel +from .flow_matching import FlowMatchingModel from .multibanddiffusion import MultiBandDiffusion from .musicgen import MusicGen from .magnet import MAGNeT from .unet import DiffusionUnet from .watermark import WMModel +from .jasco import JASCO diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index 85a1ca53..1ed3d369 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -24,15 +24,19 @@ ParallelPatternProvider, UnrolledPatternProvider) from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner, - CLAPEmbeddingConditioner, ConditionFuser, + CLAPEmbeddingConditioner, + ConditionFuser, JascoCondConst, ConditioningProvider, LUTConditioner, T5Conditioner, StyleConditioner) +from ..modules.jasco_conditioners import (JascoConditioningProvider, ChordsEmbConditioner, + DrumsConditioner, MelodyConditioner) from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor from ..utils.utils import dict_from_config from .encodec import (CompressionModel, EncodecModel, InterleaveStereoCompressionModel) from .lm import LMModel from .lm_magnet import MagnetLMModel +from .flow_matching 
import FlowMatchingModel from .unet import DiffusionUnet from .watermark import WMModel @@ -87,6 +91,48 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: raise KeyError(f"Unexpected compression model {cfg.compression_model}") +def get_jasco_model(cfg: omegaconf.DictConfig, + compression_model: tp.Optional[CompressionModel] = None) -> FlowMatchingModel: + kwargs = dict_from_config(getattr(cfg, "transformer_lm")) + attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout")) + cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance")) + cfg_prob = cls_free_guidance["training_dropout"] + cfg_coef = cls_free_guidance["inference_coef"] + fuser = get_condition_fuser(cfg) + condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) + if JascoCondConst.DRM.value in condition_provider.conditioners: # use self_wav for drums + assert compression_model is not None + + # use compression model for drums conditioning + condition_provider.conditioners.self_wav.compression_model = compression_model + condition_provider.conditioners.self_wav.compression_model.requires_grad_(False) + + # downcast to jasco conditioning provider + seq_len = cfg.compression_model_framerate * cfg.dataset.segment_duration + chords_card = cfg.conditioners.chords.chords_emb.card if JascoCondConst.CRD.value in cfg.conditioners else -1 + condition_provider = JascoConditioningProvider(device=condition_provider.device, + conditioners=condition_provider.conditioners, + chords_card=chords_card, + sequence_length=seq_len) + + if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically + kwargs["cross_attention"] = True + + kwargs.pop("n_q", None) + kwargs.pop("card", None) + + return FlowMatchingModel( + condition_provider=condition_provider, + fuser=fuser, + cfg_dropout=cfg_prob, + cfg_coef=cfg_coef, + attribute_dropout=attribute_dropout, + dtype=getattr(torch, cfg.dtype), + device=cfg.device, + **kwargs, + ).to(cfg.device) + + def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: """Instantiate a transformer LM.""" if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]: @@ -157,6 +203,12 @@ def get_conditioner_provider( conditioners[str(cond)] = ChromaStemConditioner( output_dim=output_dim, duration=duration, device=device, **model_args ) + elif model_type in {"chords_emb", "drum_latents", "melody"}: + conditioners_classes = {"chords_emb": ChordsEmbConditioner, + "drum_latents": DrumsConditioner, + "melody": MelodyConditioner} + conditioner_class = conditioners_classes[model_type] + conditioners[str(cond)] = conditioner_class(device=device, **model_args) elif model_type == "clap": conditioners[str(cond)] = CLAPEmbeddingConditioner( output_dim=output_dim, device=device, **model_args @@ -178,8 +230,8 @@ def get_conditioner_provider( def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: """Instantiate a condition fuser object.""" fuser_cfg = getattr(cfg, "fuser") - fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] - fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} + fuser_methods = ["sum", "cross", "prepend", "ignore", "input_interpolate"] + fuse2cond = {k: fuser_cfg[k] for k in fuser_methods if k in fuser_cfg} kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) return fuser diff --git a/audiocraft/models/flow_matching.py b/audiocraft/models/flow_matching.py new file mode 100644 index 00000000..1a8dd3cc --- /dev/null +++ 
b/audiocraft/models/flow_matching.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from functools import partial +import logging +import math +import typing as tp +import torch +from torch import nn +from torchdiffeq import odeint # type: ignore +from ..modules.streaming import StreamingModule +from ..modules.transformer import create_norm_fn, StreamingTransformerLayer +from ..modules.unet_transformer import UnetTransformer +from ..modules.conditioners import ( + ConditionFuser, + ClassifierFreeGuidanceDropout, + AttributeDropout, + ConditioningAttributes, + JascoCondConst +) +from ..modules.jasco_conditioners import JascoConditioningProvider +from ..modules.activations import get_activation_fn + +from .lm import ConditionTensors, init_layer + + +logger = logging.getLogger(__name__) + + +@dataclass +class FMOutput: + latents: torch.Tensor # [B, T, D] + mask: torch.Tensor # [B, T] + + +class CFGTerm: + """ + Base class for Multi Source Classifier-Free Guidance (CFG) terms. This class represents a term in the CFG process, + which is used to guide the generation process by adjusting the influence of different conditions. + Attributes: + conditions (dict): A dictionary of conditions that influence the generation process. + weight (float): The weight of the CFG term, determining its influence on the generation. + """ + def __init__(self, conditions, weight): + self.conditions = conditions + self.weight = weight + + def drop_irrelevant_conds(self, conditions): + """ + Drops irrelevant conditions from the CFG term. This method should be implemented by subclasses. + Args: + conditions (dict): The conditions to be filtered. + Raises: + NotImplementedError: If the method is not implemented in a subclass. + """ + raise NotImplementedError("No base implementation for setting generation params.") + + +class AllCFGTerm(CFGTerm): + """ + A CFG term that retains all conditions. This class does not drop any condition. + """ + def __init__(self, conditions, weight): + super().__init__(conditions, weight) + self.drop_irrelevant_conds() + + def drop_irrelevant_conds(self): + pass + + +class NullCFGTerm(CFGTerm): + """ + A CFG term that drops all conditions, effectively nullifying their influence. + """ + def __init__(self, conditions, weight): + super().__init__(conditions, weight) + self.drop_irrelevant_conds() + + def drop_irrelevant_conds(self): + """ + Drops all conditions by applying a dropout with probability 1.0, effectively nullifying their influence. + """ + self.conditions = ClassifierFreeGuidanceDropout(p=1.0)( + samples=self.conditions, + cond_types=["wav", "text", "symbolic"]) + + +class TextCFGTerm(CFGTerm): + """ + A CFG term that selectively drops conditions based on specified dropout probabilities for different types + of conditions, such as 'symbolic' and 'wav'. + """ + def __init__(self, conditions, weight, model_att_dropout): + """ + Initializes a TextCFGTerm with specified conditions, weight, and model attention dropout configuration. + Args: + conditions (dict): The conditions to be used in the CFG process. + weight (float): The weight of the CFG term. + model_att_dropout (object): The attribute dropouts used by the model. 
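+        Note:
+            Every attribute listed under 'symbolic' or 'wav' in the model's attribute dropout config is
+            dropped with probability 1.0, so only the text conditioning stays active for this CFG term.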
+ """ + super().__init__(conditions, weight) + if 'symbolic' in model_att_dropout.p: + self.drop_symbolics = {k: 1.0 for k in model_att_dropout.p['symbolic'].keys()} + else: + self.drop_symbolics = {} + if 'wav' in model_att_dropout.p: + self.drop_wav = {k: 1.0 for k in model_att_dropout.p['wav'].keys()} + else: + self.drop_wav = {} + self.drop_irrelevant_conds() + + def drop_irrelevant_conds(self): + self.conditions = AttributeDropout({'symbolic': self.drop_symbolics, + 'wav': self.drop_wav})(self.conditions) # drop temporal conds + + +class FlowMatchingModel(StreamingModule): + """ + A flow matching model inherits from StreamingModule. + This model uses a transformer architecture to process and fuse conditions, applying learned embeddings and + transformations and predicts multi-source guided vector fields. + Attributes: + condition_provider (JascoConditioningProvider): Provider for conditioning attributes. + fuser (ConditionFuser): Fuser for combining multiple conditions. + dim (int): Dimensionality of the model's main features. + num_heads (int): Number of attention heads in the transformer. + flow_dim (int): Dimensionality of the flow features. + chords_dim (int): Dimensionality for chord embeddings, if used. + drums_dim (int): Dimensionality for drums embeddings, if used. + melody_dim (int): Dimensionality for melody embeddings, if used. + hidden_scale (int): Scaling factor for the dimensionality of the feedforward network in the transformer. + norm (str): Type of normalization to use ('layer_norm' or other supported types). + norm_first (bool): Whether to apply normalization before other operations in the transformer layers. + bias_proj (bool): Whether to include bias in the projection layers. + weight_init (Optional[str]): Method for initializing weights. + depthwise_init (Optional[str]): Method for initializing depthwise convolutional layers. + zero_bias_init (bool): Whether to initialize biases to zero. + cfg_dropout (float): Dropout rate for configuration settings. + cfg_coef (float): Coefficient for configuration influence. + attribute_dropout (Dict[str, Dict[str, float]]): Dropout rates for specific attributes. + time_embedding_dim (int): Dimensionality of time embeddings. + **kwargs: Additional keyword arguments for the transformer. + Methods: + __init__: Initializes the model with the specified attributes and configuration. 
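+        forward: Predicts the vector field for noisy latents given a time parameter and condition tensors.
+        generate: Samples latents by integrating the estimated vector field (Euler or adaptive dopri5 ODE solver).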
+ """ + def __init__(self, condition_provider: JascoConditioningProvider, + fuser: ConditionFuser, + dim: int = 128, + num_heads: int = 8, + flow_dim: int = 128, + chords_dim: int = 0, + drums_dim: int = 0, + melody_dim: int = 0, + hidden_scale: int = 4, + norm: str = 'layer_norm', + norm_first: bool = False, + bias_proj: bool = True, + weight_init: tp.Optional[str] = None, + depthwise_init: tp.Optional[str] = None, + zero_bias_init: bool = False, + cfg_dropout: float = 0, + cfg_coef: float = 1.0, + attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, + time_embedding_dim: int = 128, + **kwargs): + super().__init__() + self.cfg_coef = cfg_coef + + self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout) + self.att_dropout = AttributeDropout(p=attribute_dropout) + self.condition_provider = condition_provider + self.fuser = fuser + self.dim = dim # transformer dim + self.flow_dim = flow_dim + self.chords_dim = chords_dim + self.emb = nn.Linear(flow_dim + chords_dim + drums_dim + melody_dim, dim, bias=False) + if 'activation' in kwargs: + kwargs['activation'] = get_activation_fn(kwargs['activation']) + + self.transformer = UnetTransformer( + d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim), + norm=norm, norm_first=norm_first, + layer_class=StreamingTransformerLayer, + **kwargs) + self.out_norm: tp.Optional[nn.Module] = None + if norm_first: + self.out_norm = create_norm_fn(norm, dim) + self.linear = nn.Linear(dim, flow_dim, bias=bias_proj) + self._init_weights(weight_init, depthwise_init, zero_bias_init) + self._fsdp: tp.Optional[nn.Module] + self.__dict__['_fsdp'] = None + + # init time parameter embedding + self.d_temb1 = time_embedding_dim + self.d_temb2 = 4 * time_embedding_dim + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.d_temb1, + self.d_temb2), + torch.nn.Linear(self.d_temb2, + self.d_temb2), + ]) + self.temb_proj = nn.Linear(self.d_temb2, dim) + + def _get_timestep_embedding(self, timesteps, embedding_dim): + """ + ####################################################################################################### + TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py + ####################################################################################################### + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
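+        Concretely, with half_dim = embedding_dim // 2:
+            emb[i, k]            = sin(t_i * 10000^(-k / (half_dim - 1)))   for k < half_dim
+            emb[i, half_dim + k] = cos(t_i * 10000^(-k / (half_dim - 1)))
+        and an extra zero column is appended when embedding_dim is odd.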
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + def _embed_time_parameter(self, t: torch.Tensor): + """ + ####################################################################################################### + TAKEN FROM: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/model.py + ####################################################################################################### + """ + temb = self._get_timestep_embedding(t.flatten(), self.d_temb1) + temb = self.temb.dense[0](temb) + temb = temb * torch.sigmoid(temb) # swish activation + temb = self.temb.dense[1](temb) + return temb + + def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool): + """Initialization of the transformer module weights. + + Args: + weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. + depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: + 'current' where the depth corresponds to the current layer index or 'global' where the total number + of layer is used as depth. If not set, no depthwise initialization strategy is used. + zero_bias_init (bool): Whether to initialize bias to zero or not. + """ + assert depthwise_init is None or depthwise_init in ['current', 'global'] + assert depthwise_init is None or weight_init is not None, \ + "If 'depthwise_init' is defined, a 'weight_init' method should be provided." + assert not zero_bias_init or weight_init is not None, \ + "If 'zero_bias_init', a 'weight_init' method should be provided" + + if weight_init is None: + return + + init_layer(self.emb, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + for layer_idx, tr_layer in enumerate(self.transformer.layers): + depth = None + if depthwise_init == 'current': + depth = layer_idx + 1 + elif depthwise_init == 'global': + depth = len(self.transformer.layers) + init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) + tr_layer.apply(init_fn) + + init_layer(self.linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + def _align_seq_length(self, + cond: torch.Tensor, + seq_len: int = 500): + # trim if needed + cond = cond[:, :seq_len, :] + + # pad if needed + B, T, C = cond.shape + if T < seq_len: + cond = torch.cat((cond, torch.zeros((B, seq_len - T, C), dtype=cond.dtype, device=cond.device)), dim=1) + + return cond + + def forward(self, + latents: torch.Tensor, + t: torch.Tensor, + conditions: tp.List[ConditioningAttributes], + condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor: + """Apply flow matching forward pass on latents and conditions. + Given a tensor of noisy latents of shape [B, T, D] with D the flow dim and T the sequence steps, + and a time parameter tensor t, return the vector field with shape [B, T, D]. + + Args: + latents (torch.Tensor): noisy latents. + conditions (list of ConditioningAttributes): Conditions to use when modeling + the given codes. 
Note that when evaluating multiple time with the same conditioning + you should pre-compute those and pass them as `condition_tensors`. + condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning + tensors, see `conditions`. + Returns: + torch.Tensor: estimated vector field v_theta. + """ + assert condition_tensors is not None, "FlowMatchingModel require pre-calculation of condition tensors" + assert not conditions, "Shouldn't pass unprocessed conditions to FlowMatchingModel." + + B, T, D = latents.shape + x = latents + + # concat temporal conditions on the feature dimension + temporal_conds = JascoCondConst.ALL.value + for cond in temporal_conds: + if cond not in condition_tensors: + continue + c = self._align_seq_length(condition_tensors[cond][0], seq_len=T) + x = torch.concat((x, c), dim=-1) + + # project to transformer dimension + input_ = self.emb(x) + + input_, cross_attention_input = self.fuser(input_, condition_tensors) + + # embed time parameter + t_embs = self._embed_time_parameter(t) + + # add it to cross_attention_input + cross_attention_input = cross_attention_input + self.temb_proj(t_embs[:, None, :]) + + out = self.transformer(input_, cross_attention_src=cross_attention_input) + + if self.out_norm: + out = self.out_norm(out) + v_theta = self.linear(out) # [B, T, D] + + # remove the prefix from the model outputs + if len(self.fuser.fuse2cond['prepend']) > 0: + v_theta = v_theta[:, :, -T:] + + return v_theta # [B, T, D] + + def _multi_source_cfg_preprocess(self, + conditions: tp.List[ConditioningAttributes], + cfg_coef_all: float, + cfg_coef_txt: float, + min_weight: float = 1e-6): + """ + Preprocesses the CFG terms for multi-source conditional generation. + Args: + conditions (list): A list of conditions to be applied. + cfg_coef_all (float): The coefficient for all conditions. + cfg_coef_txt (float): The coefficient for text conditions. + min_weight (float): The minimal absolute weight for calculating a CFG term. + Returns: + tuple: A tuple containing condition_tensors and cfg_terms. + condition_tensors is a dictionary or ConditionTensors object with tokenized conditions. + cfg_terms is a list of CFGTerm objects with weights adjusted based on the coefficients. + """ + condition_tensors: tp.Optional[ConditionTensors] + cfg_terms = [] + if conditions: + # conditional terms + cfg_terms = [AllCFGTerm(conditions=conditions, weight=cfg_coef_all), + TextCFGTerm(conditions=conditions, weight=cfg_coef_txt, + model_att_dropout=self.att_dropout)] + + # add null term + cfg_terms.append(NullCFGTerm(conditions=conditions, weight=1 - sum([ct.weight for ct in cfg_terms]))) + + # remove terms with negligible weight + for ct in cfg_terms: + if abs(ct.weight) < min_weight: + cfg_terms.remove(ct) + + conds: tp.List[ConditioningAttributes] = sum([ct.conditions for ct in cfg_terms], []) + tokenized = self.condition_provider.tokenize(conds) + condition_tensors = self.condition_provider(tokenized) + else: + condition_tensors = {} + + return condition_tensors, cfg_terms + + def estimated_vector_field(self, z, t, condition_tensors=None, cfg_terms=[]): + """ + Estimates the vector field for the given latent variables and time parameter, + conditioned on the provided conditions. + Args: + z (Tensor): The latent variables. + t (float): The time variable. + condition_tensors (ConditionTensors, optional): The condition tensors. Defaults to None. + cfg_terms (list, optional): The list of CFG terms. Defaults to an empty list. + Returns: + Tensor: The estimated vector field. 
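+        Note:
+            When several CFG terms are given, z is replicated once per term and the per-term vector fields
+            are combined as a weighted sum: v = sum_i w_i * v_theta(z, t | conditions_i).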
+ """ + if len(cfg_terms) > 1: + z = z.repeat(len(cfg_terms), 1, 1) # duplicate noisy latents for multi-source CFG + v_thetas = self(latents=z, t=t, conditions=[], condition_tensors=condition_tensors) + return self._multi_source_cfg_postprocess(v_thetas, cfg_terms) + + def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms): + """ + Postprocesses the vector fields generated for each CFG term to combine them into a single vector field. + Multi source guidance occurs here. + Args: + v_thetas (Tensor): The vector fields for each CFG term. + cfg_terms (list): The CFG terms used. + Returns: + Tensor: The combined vector field. + """ + if len(cfg_terms) <= 1: + return v_thetas + v_theta_per_term = v_thetas.chunk(len(cfg_terms)) + return sum([ct.weight * term_vf for ct, term_vf in zip(cfg_terms, v_theta_per_term)]) + + @torch.no_grad() + def generate(self, + prompt: tp.Optional[torch.Tensor] = None, + conditions: tp.List[ConditioningAttributes] = [], + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + cfg_coef_all: float = 3.0, + cfg_coef_txt: float = 1.0, + euler: bool = False, + euler_steps: int = 100, + ode_rtol: float = 1e-5, + ode_atol: float = 1e-5, + ) -> torch.Tensor: + """ + Generate audio latents given a prompt or unconditionally. This method supports both Euler integration + and adaptive ODE solving to generate sequences based on the specified conditions and configuration coefficients. + + Args: + prompt (torch.Tensor, optional): Initial prompt to condition the generation. defaults to None + conditions (List[ConditioningAttributes]): List of conditioning attributes - text, symbolic or audio. + num_samples (int, optional): Number of samples to generate. + If None, it is inferred from the number of conditions. + max_gen_len (int): Maximum length of the generated sequence. + callback (Callable[[int, int], None], optional): Callback function to monitor the generation process. + cfg_coef_all (float): Coefficient for the fully conditional CFG term. + cfg_coef_txt (float): Coefficient for text CFG term. + euler (bool): If True, use Euler integration, otherwise use adaptive ODE solver. + euler_steps (int): Number of Euler steps to perform if Euler integration is used. + ode_rtol (float): ODE solver rtol threshold. + ode_atol (float): ODE solver atol threshold. + + Returns: + torch.Tensor: Generated latents, shaped as (num_samples, max_gen_len, feature_dim). + """ + + assert not self.training, "generation shouldn't be used in training mode." + first_param = next(iter(self.parameters())) + device = first_param.device + + # Checking all input shapes are consistent. 
+ possible_num_samples = [] + if num_samples is not None: + possible_num_samples.append(num_samples) + elif prompt is not None: + possible_num_samples.append(prompt.shape[0]) + elif conditions: + possible_num_samples.append(len(conditions)) + else: + possible_num_samples.append(1) + assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" + num_samples = possible_num_samples[0] + + condition_tensors, cfg_terms = self._multi_source_cfg_preprocess(conditions, cfg_coef_all, cfg_coef_txt) + + # flow matching inference + B, T, D = num_samples, max_gen_len, self.flow_dim + + z_0 = torch.randn((B, T, D), device=device) + + if euler: + # vanilla Euler intergration + dt = (1 / euler_steps) + z = z_0 + t = torch.zeros((1, ), device=device) + for _ in range(euler_steps): + v_theta = self.estimated_vector_field(z, t, + condition_tensors=condition_tensors, + cfg_terms=cfg_terms) + z = z + dt * v_theta + t = t + dt + z_1 = z + else: + # solve with dynamic ode integrator (dopri5) + t = torch.tensor([0, 1.0 - 1e-5], device=device) + num_evals = 0 + + # define ode vector field function + def inner_ode_func(t, z): + nonlocal num_evals + num_evals += 1 + if callback is not None: + ESTIMATED_ODE_SOLVER_STEPS = 300 + callback(num_evals, ESTIMATED_ODE_SOLVER_STEPS) + return self.estimated_vector_field(z, t, + condition_tensors=condition_tensors, + cfg_terms=cfg_terms) + + ode_opts: dict = {"options": {}} + z = odeint( + inner_ode_func, + z_0, + t, + **{"atol": ode_atol, "rtol": ode_rtol, **ode_opts}, + ) + logger.info("Generated in %d steps", num_evals) + z_1 = z[-1] + + return z_1 diff --git a/audiocraft/models/jasco.py b/audiocraft/models/jasco.py new file mode 100644 index 00000000..0a7bf7f1 --- /dev/null +++ b/audiocraft/models/jasco.py @@ -0,0 +1,326 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Main model for using JASCO. This will combine all the required components +and provide easy access to the generation API. +""" +import os +import math +import pickle +import torch +import typing as tp + +from audiocraft.utils.utils import construct_frame_chords +from .genmodel import BaseGenModel +from .loaders import load_compression_model, load_jasco_model +from ..data.audio_utils import convert_audio +from ..modules.conditioners import WavCondition, ConditioningAttributes, SymbolicCondition, JascoCondConst + + +class JASCO(BaseGenModel): + """JASCO main model with convenient generation API. + Args: + chords_mapping_path: path to chords to index mapping pickle + kwargs - See MusicGen class. + """ + def __init__(self, chords_mapping_path='assets/chord_to_index_mapping.pkl', **kwargs): + super().__init__(**kwargs) + # JASCO operates over a fixed sequence length defined in it's config. + self.duration = self.lm.cfg.dataset.segment_duration + + # load chord2index mapping of Chordino (https://github.com/ohollo/chord-extractor) + assert os.path.exists(chords_mapping_path) + self.chords_mapping = pickle.load(open(chords_mapping_path, "rb")) + + # set generation parameters + self.set_generation_params() + + @staticmethod + def get_pretrained(name: str = 'facebook/jasco-chords-drums-400M', device=None, + chords_mapping_path='assets/chord_to_index_mapping.pkl'): + """Return pretrained model, we provide 2 models: + 1. 
facebook/jasco-chords-drums-400M: 10s music generation conditioned on + text, chords and drums, 400M parameters. + 2. facebook/jasco-chords-drums-1B: 10s music generation conditioned on + text, chords and drums, 1B parameters. + """ + if device is None: + if torch.cuda.device_count(): + device = 'cuda' + else: + device = 'cpu' + + compression_model = load_compression_model(name, device=device) + lm = load_jasco_model(name, compression_model, device=device) + + kwargs = {'name': name, + 'compression_model': compression_model, + 'lm': lm, + 'chords_mapping_path': chords_mapping_path} + return JASCO(**kwargs) + + def set_generation_params(self, + cfg_coef_all: float = 5.0, + cfg_coef_txt: float = 0.0, + **kwargs): + """Set the generation parameters for JASCO. + + Args: + cfg_coef_all (float, optional): Coefficient used in multi-source classifier free guidance - + all conditions term. Defaults to 5.0. + cfg_coef_txt (float, optional): Coefficient used in multi-source classifier free guidance - + text condition term. Defaults to 0.0. + + """ + self.generation_params = { + 'cfg_coef_all': cfg_coef_all, + 'cfg_coef_txt': cfg_coef_txt + } + self.generation_params.update(kwargs) + + def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Unnormalize latents, shifting back to EnCodec's expected mean, std""" + assert self.cfg is not None + scaled = latents * self.cfg.compression_model_latent_std + return scaled + self.cfg.compression_model_latent_mean + + def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor: + """Decode audio from generated latents""" + assert gen_latents.dim() == 3 # [B, T, C] + + # unnormalize latents + gen_latents = self._unnormalized_latents(gen_latents) + return self.compression_model.model.decoder(gen_latents.permute(0, 2, 1)) + + def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], + prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: + """Generate continuous audio latents given conditions. + + Args: + attributes (list of ConditioningAttributes): Conditions used for generation (here text). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + Returns: + torch.Tensor: Generated latents, of shape [B, T, C]. + """ + total_gen_len = int(self.duration * self.frame_rate) + max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) + + def _progress_callback(ode_steps: int, max_ode_steps: int): + ode_steps += 1 + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. 
+ self._progress_callback(ode_steps, max_ode_steps) + else: + print(f'{ode_steps: 6d} / {max_ode_steps: 6d}', end='\r') + + if prompt_tokens is not None: + assert max_prompt_len >= prompt_tokens.shape[-1], \ + "Prompt is longer than audio to generate" + + callback = None + if progress: + callback = _progress_callback + + # generate by sampling from the LM + with self.autocast: + total_gen_len = math.ceil(self.duration * self.compression_model.frame_rate) + return self.lm.generate( + prompt_tokens, attributes, + callback=callback, max_gen_len=total_gen_len, **self.generation_params) + + def _prepare_chord_conditions( + self, + attributes: tp.List[ConditioningAttributes], + chords: tp.Optional[tp.List[tp.Tuple[str, float]]], + ) -> tp.List[ConditioningAttributes]: + """ + Prepares chord conditions by translating symbolic chord progressions into a sequence of integers. + This method updates the ConditioningAttributes with per-frame chords information. + Args: + attributes (List[ConditioningAttributes]): + The initial attributes and optional tensor data. + chords (List[Tuple[str, float]]): + A list of tuples containing chord labels and their start times. + Returns: + List[ConditioningAttributes]: + The updated attributes with frame chords integrated, alongside the original optional tensor data. + """ + if chords is None or chords == []: + for att in attributes: + att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=-1 * + torch.ones(1, dtype=torch.int32)) + return attributes + + # flip from (chord, start_time) to (start_time, chord) + chords_time_first: tp.List[tuple[float, str]] = [(item[1], item[0]) for item in chords] + + # translate symbolic chord progression into a sequence of ints + frame_chords = construct_frame_chords(min_timestamp=0, + chord_changes=chords_time_first, + mapping_dict=self.chords_mapping, + prev_chord='', + frame_rate=self.compression_model.frame_rate, + segment_duration=self.duration) + # update the attribute objects + for att in attributes: + att.symbolic[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.tensor(frame_chords)) + return attributes + + @torch.no_grad() + def _prepare_drums_conditions(self, + attributes: + tp.List[ConditioningAttributes], + drums_wav: tp.Optional[torch.Tensor], + ): + # prepare drums cond + for attr in attributes: + if drums_wav is None: + attr.wav[JascoCondConst.DRM.value] = WavCondition( + torch.zeros((1, 1, 1), device=self.device), + torch.tensor([0], device=self.device), + sample_rate=[self.sample_rate], + path=[None]) + else: + if JascoCondConst.DRM.value not in self.lm.condition_provider.conditioners: + raise RuntimeError("This model doesn't support drums conditioning. 
") + + expected_length = self.lm.cfg.dataset.segment_duration * self.sample_rate + # trim if needed + drums_wav = drums_wav[..., :expected_length] + + # pad if needed + if drums_wav.shape[-1] < expected_length: + diff = expected_length - drums_wav.shape[-1] + diff_zeros = torch.zeros((drums_wav.shape[0], drums_wav.shape[1], diff), + device=drums_wav.device, dtype=drums_wav.dtype) + drums_wav = torch.cat((drums_wav, diff_zeros), dim=-1) + + attr.wav[JascoCondConst.DRM.value] = WavCondition( + drums_wav.to(device=self.device), + torch.tensor([drums_wav.shape[-1]], device=self.device), + sample_rate=[self.sample_rate], + path=[None], + ) + + return attributes + + @torch.no_grad() + def _prepare_melody_conditions( + self, + attributes: tp.List[ConditioningAttributes], + melody: tp.Optional[torch.Tensor], + expected_length: int, + melody_bins: int = 53, + ) -> tp.List[ConditioningAttributes]: + """ + Prepares melody conditions by subtituting with pre-computed salience matrix. + This method updates the ConditioningAttributes with per-frame chords information. + Args: + attributes (List[ConditioningAttributes]): + The initial attributes and optional tensor data. + chords (List[Tuple[str, float]]): + A list of tuples containing chord labels and their start times. + Returns: + List[ConditioningAttributes]: + The updated attributes with frame chords integrated, alongside the original optional tensor data. + """ + for attr in attributes: + if melody is None: + melody = torch.zeros((melody_bins, expected_length)) + attr.symbolic[JascoCondConst.MLD.value] = SymbolicCondition(melody=melody) + return attributes + + @torch.no_grad() + def _prepare_temporal_conditions( + self, + attributes: tp.List[ConditioningAttributes], + expected_length: int, + chords: tp.Optional[tp.List[tp.Tuple[str, float]]], + drums_wav: tp.Optional[torch.Tensor], + salience_matrix: tp.Optional[torch.Tensor], + melody_bins: int = 53, + ) -> tp.List[ConditioningAttributes]: + """ + Prepares temporal conditions (chords, drums). + Args: + attributes (List[ConditioningAttributes]): The initial attributes and optional tensor data. + expected_length (int): The expected number of generated frames. + chords (List[Tuple[str, float]]): A list of tuples containing chord labels and their start times. + drums_wav (List[Tuple[str, float]]): tensor of extracted drums wav. + salience_matrix (List[Tuple[str, float]]): melody matrix. + melody_bins (int): number of melody bins the model was trained with, only relevant if trained with melody. + Returns: + List[ConditioningAttributes]: + The updated attributes after processing chord conditions. 
+ """ + attributes = self._prepare_chord_conditions(attributes=attributes, chords=chords) + attributes = self._prepare_drums_conditions(attributes=attributes, drums_wav=drums_wav) + attributes = self._prepare_melody_conditions(attributes=attributes, melody=salience_matrix, + expected_length=expected_length, melody_bins=melody_bins) + return attributes + + @torch.no_grad() + def generate_music( + self, descriptions: tp.List[str], + drums_wav: tp.Optional[torch.Tensor] = None, + drums_sample_rate: int = 32000, + chords: tp.Optional[tp.List[tp.Tuple[str, float]]] = None, + melody_salience_matrix: tp.Optional[torch.Tensor] = None, + iopaint_wav: tp.Optional[torch.Tensor] = None, + segment_duration: float = 10.0, + frame_rate: float = 50.0, + melody_bins: int = 53, + progress: bool = False, return_latents: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on text and temporal conditions (chords, melody, drums). + + Args: + descriptions (list of str): A list of strings used as text conditioning. + chords (list of (str, float) tuples): Chord progression represented as chord, start time (sec), e.g.: + [("C", 0.0), ("F", 4.0), ("G", 6.0), ("C", 8.0)] + melody_salience_matrix (torch.Tensor, optional): melody saliency matrix. Default=None. + iopaint_wav (torch.Tensor, optional): in/out=painting waveform. Default=None. + segment_duration (float): the segment duration the model was trained on. Default=None. + frame_rate (float): the frame_rate model was trained on. Default=None. + melody_bins (int): number of melody bins the model was trained with, only relevant if trained with melody. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + + if drums_wav is not None: + if drums_wav.dim() == 2: + drums_wav = drums_wav[None] + assert drums_wav.dim() == 3, "drums wav should have a shape [B, C, T]." + drums_wav = convert_audio(drums_wav, drums_sample_rate, self.sample_rate, self.audio_channels) + + cond_attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, + prompt=None) + + # prepare temporal conds (symbolic / audio) + jasco_attributes = self._prepare_temporal_conditions(attributes=cond_attributes, + expected_length=int(segment_duration * frame_rate), + chords=chords, + drums_wav=drums_wav, + salience_matrix=melody_salience_matrix, + melody_bins=melody_bins) + assert prompt_tokens is None + tokens = self._generate_tokens(jasco_attributes, prompt_tokens, progress) + if return_latents: + return self.generate_audio(tokens), tokens + return self.generate_audio(tokens) + + @torch.no_grad() + def generate(self, descriptions: tp.List[str], progress: bool = False, return_latents: bool = False) \ + -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on text. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. 
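+
+        Example (illustrative; the text prompt is arbitrary)::
+
+            model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M')
+            wav = model.generate(descriptions=['80s pop with groovy synth bass'])
+
+        For chord, drums or melody conditioning use generate_music() instead.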
+ """ + return self.generate_music(descriptions=descriptions, progress=progress, return_latents=return_latents) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 3c7dd069..af370ceb 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -155,6 +155,23 @@ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_mod return model +def load_jasco_model(file_or_url_or_id: tp.Union[Path, str], + compression_model: CompressionModel, + device='cpu', cache_dir: tp.Optional[str] = None): + pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir) + cfg = OmegaConf.create(pkg['xp.cfg']) + cfg.device = str(device) + if cfg.device == 'cpu': + cfg.dtype = 'float32' + else: + cfg.dtype = 'float16' + model = builders.get_jasco_model(cfg, compression_model) + model.load_state_dict(pkg['best_state']) + model.eval() + model.cfg = cfg + return model + + def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], filename: tp.Optional[str] = None, cache_dir: tp.Optional[str] = None): diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py index 7bb5b497..76de9ab8 100644 --- a/audiocraft/modules/conditioners.py +++ b/audiocraft/modules/conditioners.py @@ -15,7 +15,6 @@ import re import typing as tp import warnings - import einops import flashy from num2words import num2words @@ -25,7 +24,7 @@ from torch import nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence - +from enum import Enum from .chroma import ChromaExtractor from .streaming import StreamingModule from .transformer import create_sin_embedding, StreamingTransformer @@ -44,6 +43,15 @@ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask +class JascoCondConst(Enum): + DRM = 'self_wav' + CRD = 'chords' + MLD = 'melody' + SYM = {'chords', 'melody'} + LAT = {'self_wav'} + ALL = ['chords', 'self_wav', 'melody'] # order matters + + class WavCondition(tp.NamedTuple): wav: torch.Tensor length: torch.Tensor @@ -61,11 +69,17 @@ class JointEmbedCondition(tp.NamedTuple): seek_time: tp.List[tp.Optional[float]] = [] +class SymbolicCondition(tp.NamedTuple): + frame_chords: tp.Optional[torch.Tensor] = None + melody: tp.Optional[torch.Tensor] = None + + @dataclass class ConditioningAttributes: text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) wav: tp.Dict[str, WavCondition] = field(default_factory=dict) joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) + symbolic: tp.Dict[str, SymbolicCondition] = field(default_factory=dict) def __getitem__(self, item): return getattr(self, item) @@ -82,19 +96,25 @@ def wav_attributes(self): def joint_embed_attributes(self): return self.joint_embed.keys() + @property + def symbolic_attributes(self): + return self.symbolic.keys() + @property def attributes(self): return { "text": self.text_attributes, "wav": self.wav_attributes, "joint_embed": self.joint_embed_attributes, + "symbolic": self.symbolic_attributes, } def to_flat_dict(self): return { **{f"text.{k}": v for k, v in self.text.items()}, **{f"wav.{k}": v for k, v in self.wav.items()}, - **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} + **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}, + **{f"symbolic.{k}": v for k, v in self.symbolic.items()} } @classmethod @@ -178,6 +198,28 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: ) +def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 194) -> SymbolicCondition: + 
"""Nullify the symbolic condition by setting all frame chords to a specified null chord index. + Args: + sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified. + null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino). + Returns: + SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index. + """ + return SymbolicCondition(frame_chords=torch.ones_like(sym_cond.frame_chords) * null_chord_idx) # type: ignore + + +def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition: + """Nullify the symbolic condition by replacing the melody matrix with zeros matrix. + Args: + sym_cond (SymbolicCondition): The symbolic condition containing frame chords to be nullified. + null_chord_idx (int, optional): The index to use for nullifying the chords. Defaults to 194 (Chordino). + Returns: + SymbolicCondition: A new symbolic condition with all frame chords set to the null chord index. + """ + return SymbolicCondition(melody=torch.zeros_like(sym_cond.melody)) # type: ignore + + def _drop_description_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """Drop the text condition but keep the wav conditon on a list of ConditioningAttributes. This is useful to calculate l_style in the double classifier free guidance formula. @@ -314,7 +356,8 @@ def __init__(self, dim: int, output_dim: int): super().__init__() self.dim = dim self.output_dim = output_dim - self.output_proj = nn.Linear(dim, output_dim) + if self.output_dim > -1: # omit projection when output_dim <= 0 + self.output_proj = nn.Linear(dim, output_dim) def tokenize(self, *args, **kwargs) -> tp.Any: """Should be any part of the processing that will lead to a synchronization @@ -512,8 +555,9 @@ def forward(self, x: WavCondition) -> ConditionType: wav, lengths, *_ = x with torch.no_grad(): embeds = self._get_wav_embedding(x) - embeds = embeds.to(self.output_proj.weight) - embeds = self.output_proj(embeds) + if hasattr(self, 'output_proj'): + embeds = embeds.to(self.output_proj.weight) + embeds = self.output_proj(embeds) if lengths is not None and self._use_masking: lengths = lengths / self._downsampling_factor() @@ -1257,13 +1301,48 @@ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Ten return embed, empty_idx -def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: +def dropout_symbolic_conditions(sample: ConditioningAttributes, + condition: str, null_chord_idx: int = 194) -> ConditioningAttributes: + """ + Applies dropout to symbolic conditions within the sample based on the specified condition by setting the condition + value to a null index. + Args: + sample (ConditioningAttributes): The sample containing symbolic attributes to potentially dropout. + condition (str): The specific condition within the symbolic attributes to apply dropout. + null_chord_idx (int, optional): The index used to represent a null chord. Defaults to 194. + Returns: + ConditioningAttributes: The modified sample with dropout applied to the specified condition. + Raises: + ValueError: If the specified condition is not present in the sample's symbolic attributes. 
+ """ + if sample.symbolic == {} or sample.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1: # type: ignore + # nothing to drop + return sample + + if condition not in getattr(sample, 'symbolic'): + raise ValueError( + "dropout_symbolic_condition received an unexpected condition!" + f" expected {sample.symbolic.keys()}" + f" but got '{condition}'!" + ) + + if condition == JascoCondConst.CRD.value: + sample.symbolic[condition] = nullify_chords(sample.symbolic[condition], null_chord_idx=null_chord_idx) + elif condition == JascoCondConst.MLD.value: + sample.symbolic[condition] = nullify_melody(sample.symbolic[condition]) + + return sample + + +def dropout_condition(sample: ConditioningAttributes, + condition_type: str, condition: str, + **kwargs) -> ConditioningAttributes: """Utility function for nullifying an attribute inside an ConditioningAttributes object. If the condition is of type "wav", then nullify it using `nullify_condition` function. If the condition is of any other type, set its value to None. Works in-place. """ - if condition_type not in ['text', 'wav', 'joint_embed']: + if condition_type not in ['text', 'wav', 'joint_embed', 'symbolic']: raise ValueError( "dropout_condition got an unexpected condition type!" f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" @@ -1282,6 +1361,8 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi elif condition_type == 'joint_embed': embed = sample.joint_embed[condition] sample.joint_embed[condition] = nullify_joint_embed(embed) + elif condition_type == 'symbolic': + sample = dropout_symbolic_conditions(sample=sample, condition=condition, **kwargs) else: sample.text[condition] = None @@ -1332,7 +1413,7 @@ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[Condition return samples samples = deepcopy(samples) - for condition_type, ps in self.p.items(): # for condition types [text, wav] + for condition_type, ps in self.p.items(): # for condition types [text, wav, symbolic] for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) if torch.rand(1, generator=self.rng).item() < p: for sample in samples: @@ -1355,7 +1436,9 @@ def __init__(self, p: float, seed: int = 1234): super().__init__(seed=seed) self.p = p - def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + def forward(self, samples: tp.List[ConditioningAttributes], + cond_types: tp.List[str] = ["wav", "text"], + **kwargs) -> tp.List[ConditioningAttributes]: """ Args: samples (list[ConditioningAttributes]): List of conditions. @@ -1372,10 +1455,11 @@ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[Condition # nullify conditions of all attributes samples = deepcopy(samples) - for condition_type in ["wav", "text"]: + for condition_type in cond_types: for sample in samples: for condition in sample.attributes[condition_type]: - dropout_condition(sample, condition_type, condition) + dropout_condition(sample, condition_type, condition, + **kwargs) return samples def __repr__(self): @@ -1600,7 +1684,7 @@ class ConditionFuser(StreamingModule): cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention. cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used. 
""" - FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"] + FUSING_METHODS = ["sum", "prepend", "cross", "ignore", "input_interpolate"] def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False, cross_attention_pos_emb_scale: float = 1.0): @@ -1660,6 +1744,8 @@ def forward( cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) else: cross_attention_output = cond + elif op == 'ignore': + continue else: raise ValueError(f"unknown op ({op})") diff --git a/audiocraft/modules/jasco_conditioners.py b/audiocraft/modules/jasco_conditioners.py new file mode 100644 index 00000000..ee6b088c --- /dev/null +++ b/audiocraft/modules/jasco_conditioners.py @@ -0,0 +1,298 @@ +import torch +import typing as tp +from itertools import chain +from pathlib import Path +from torch import nn +from .conditioners import (ConditioningAttributes, BaseConditioner, ConditionType, + ConditioningProvider, JascoCondConst, + WaveformConditioner, WavCondition, SymbolicCondition) +from ..data.audio import audio_read +from ..data.audio_utils import convert_audio +from ..utils.autocast import TorchAutocast +from ..utils.cache import EmbeddingCache + + +class MelodyConditioner(BaseConditioner): + """ + A conditioner that handles melody conditioning from pre-computed salience matrix. + Attributes: + card (int): The cardinality of the melody matrix. + out_dim (int): The dimensionality of the output projection. + device (Union[torch.device, str]): The device on which the embeddings are stored. + """ + def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs): + super().__init__(dim=card, output_dim=out_dim) + self.device = device + + def tokenize(self, x: SymbolicCondition) -> SymbolicCondition: + return SymbolicCondition(melody=x.melody.to(self.device)) # type: ignore + + def forward(self, x: SymbolicCondition) -> ConditionType: + embeds = self.output_proj(x.melody.permute(0, 2, 1)) # type: ignore + mask = torch.ones_like(embeds[..., 0]) + return embeds, mask + + +class ChordsEmbConditioner(BaseConditioner): + """ + A conditioner that embeds chord symbols into a continuous vector space. + Attributes: + card (int): The cardinality of the chord vocabulary. + out_dim (int): The dimensionality of the output embeddings. + device (Union[torch.device, str]): The device on which the embeddings are stored. 
+ """ + def __init__(self, card: int, out_dim: int, device: tp.Union[torch.device, str] = 'cpu', **kwargs): + vocab_size = card + 1 # card + 1 - for null chord used during dropout + super().__init__(dim=vocab_size, output_dim=-1) # out_dim=-1 to avoid another projection + self.emb = nn.Embedding(vocab_size, out_dim, device=device) + self.device = device + + def tokenize(self, x: SymbolicCondition) -> SymbolicCondition: + return SymbolicCondition(frame_chords=x.frame_chords.to(self.device)) # type: ignore + + def forward(self, x: SymbolicCondition) -> ConditionType: + embeds = self.emb(x.frame_chords) + mask = torch.ones_like(embeds[..., 0]) + return embeds, mask + + +class DrumsConditioner(WaveformConditioner): + def __init__(self, out_dim: int, sample_rate: int, blurring_factor: int = 3, + cache_path: tp.Optional[tp.Union[str, Path]] = None, + compression_model_latent_dim: int = 128, + compression_model_framerate: float = 50, + segment_duration: float = 10.0, + device: tp.Union[torch.device, str] = 'cpu', + **kwargs): + """Drum condition conditioner + + Args: + out_dim (int): _description_ + sample_rate (int): _description_ + blurring_factor (int, optional): _description_. Defaults to 3. + cache_path (tp.Optional[tp.Union[str, Path]], optional): path to precomputed cache. Defaults to None. + compression_model_latent_dim (int, optional): latent dimensino. Defaults to 128. + compression_model_framerate (float, optional): frame rate of the representation model. Defaults to 50. + segment_duration (float, optional): duration in sec for each audio segment. Defaults to 10.0. + device (tp.Union[torch.device, str], optional): device. Defaults to 'cpu'. + """ + from demucs import pretrained + self.sample_rate = sample_rate + self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) + stem_sources: list = self.demucs.sources # type: ignore + self.stem_idx = stem_sources.index('drums') + self.compression_model = None + self.latent_dim = compression_model_latent_dim + super().__init__(dim=self.latent_dim, output_dim=out_dim, device=device) + self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) + self._use_masking = False + self.blurring_factor = blurring_factor + self.seq_len = int(segment_duration * compression_model_framerate) + self.cache = None + if cache_path is not None: + self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, + compute_embed_fn=self._calc_coarse_drum_codes_for_cache, + extract_embed_fn=self._load_drum_codes_chunk) + + @torch.no_grad() + def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + """Get parts of the wav that holds the drums, extracting the main stems from the wav.""" + from demucs.apply import apply_model + from demucs.audio import convert_audio + with self.autocast: + wav = convert_audio( + wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore + stems = apply_model(self.demucs, wav, device=self.device) + drum_stem = stems[:, self.stem_idx] # extract relevant stems for drums conditioning + return convert_audio(drum_stem, self.demucs.samplerate, self.sample_rate, 1) # type: ignore + + def _temporal_blur(self, z: torch.Tensor): + # z: (B, T, C) + B, T, C = z.shape + if T % self.blurring_factor != 0: + # pad with reflect for T % self.temporal_blurring on the right in dim=1 + pad_val = self.blurring_factor - T % self.blurring_factor + z = torch.nn.functional.pad(z, (0, 0, 0, pad_val), mode='reflect') + z = z.reshape(B, -1, 
self.blurring_factor, C).sum(dim=2) / self.blurring_factor + z = z.unsqueeze(2).repeat(1, 1, self.blurring_factor, 1).reshape(B, -1, C) + z = z[:, :T] + assert z.shape == (B, T, C) + return z + + @torch.no_grad() + def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: + assert self.compression_model is not None + + # stem separation of drums + drums = self._get_drums_stem(wav, sample_rate) + + # continuous encoding with compression model + latents = self.compression_model.model.encoder(drums) + + # quantization to coarsest codebook + coarsest_quantizer = self.compression_model.model.quantizer.layers[0] + drums = coarsest_quantizer.encode(latents).to(torch.int16) + return drums + + @torch.no_grad() + def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path], + x: WavCondition, idx: int, + max_duration_to_process: float = 600) -> torch.Tensor: + """Extract blurred drum latents from the whole audio waveform at the given path.""" + wav, sr = audio_read(path) + wav = wav[None].to(self.device) + wav = convert_audio(wav, sr, self.sample_rate, to_channels=1) + + max_frames_to_process = int(max_duration_to_process * self.sample_rate) + if wav.shape[-1] > max_frames_to_process: + # process very long tracks in chunks + start = 0 + codes = [] + while start < wav.shape[-1] - 1: + wav_chunk = wav[..., start: start + max_frames_to_process] + codes.append(self._extract_coarse_drum_codes(wav_chunk, self.sample_rate)[0]) + start += max_frames_to_process + return torch.cat(codes) + + return self._extract_coarse_drum_codes(wav, self.sample_rate)[0] + + def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: + """Extract a chunk of coarse drum codes from the full coarse drum codes derived from the full waveform.""" + wav_length = x.wav.shape[-1] + seek_time = x.seek_time[idx] + assert seek_time is not None, ( + "WavCondition seek_time is required " + "when extracting chunks from pre-computed drum codes.") + assert self.compression_model is not None + frame_rate = self.compression_model.frame_rate + target_length = int(frame_rate * wav_length / self.sample_rate) + target_length = max(target_length, self.seq_len) + index = int(frame_rate * seek_time) + out = full_coarse_drum_codes[index: index + target_length] + # pad + out = torch.cat((out, torch.zeros(target_length - out.shape[0], dtype=out.dtype, device=out.device))) + return out.to(self.device) + + @torch.no_grad() + def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: + bs = x.wav.shape[0] + if x.wav.shape[-1] <= 1: + # null condition + return torch.zeros((bs, self.seq_len, self.latent_dim), device=x.wav.device, dtype=x.wav.dtype) + + # extract coarse drum codes + no_undefined_paths = all(p is not None for p in x.path) + no_nullified_cond = x.wav.shape[-1] > 1 + if self.cache is not None and no_undefined_paths and no_nullified_cond: + paths = [Path(p) for p in x.path if p is not None] + codes = self.cache.get_embed_from_cache(paths, x) + else: + assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." 
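                # no cache (or missing file paths): compute the drum codes on the fly, i.e. Demucs drum-stem separation followed by coarse RVQ encoding of the separated stem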
+ codes = self._extract_coarse_drum_codes(x.wav, x.sample_rate[0]) + + assert self.compression_model is not None + # decode back to the continuous representation of compression model + codes = codes.unsqueeze(1).permute(1, 0, 2) # (B, T) -> (1, B, T) + codes = codes.to(torch.int64) + latents = self.compression_model.model.quantizer.decode(codes) + + latents = latents.permute(0, 2, 1) # [B, C, T] -> [B, T, C] + + # temporal blurring + return self._temporal_blur(latents) + + def tokenize(self, x: WavCondition) -> WavCondition: + """Apply WavConditioner tokenization and populate cache if needed.""" + x = super().tokenize(x) + no_undefined_paths = all(p is not None for p in x.path) + if self.cache is not None and no_undefined_paths: + paths = [Path(p) for p in x.path if p is not None] + self.cache.populate_embed_cache(paths, x) + return x + + +class JascoConditioningProvider(ConditioningProvider): + """ + A cond-provider that manages and tokenizes various types of conditioning attributes for Jasco models. + Attributes: + chords_card (int): The cardinality of the chord vocabulary. + sequence_length (int): The length of the sequence for padding purposes. + melody_dim (int): The dimensionality of the melody matrix. + """ + def __init__(self, *args, + chords_card: int = 194, + sequence_length: int = 500, + melody_dim: int = 53, **kwargs): + self.null_chord = chords_card + self.sequence_len = sequence_length + self.melody_dim = melody_dim + super().__init__(*args, **kwargs) + + def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: + """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. + This should be called before starting any real GPU work to avoid synchronization points. + This will return a dict matching conditioner names to their arbitrary tokenized representations. + + Args: + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing + text and wav conditions. + """ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", + f" but types were {set([type(x) for x in inputs])}" + ) + + output = {} + text = self._collate_text(inputs) + wavs = self._collate_wavs(inputs) + + symbolic = self._collate_symbolic(inputs, self.conditioners.keys()) + + assert set(text.keys() | wavs.keys() | symbolic.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! 
Expected {self.conditioners.keys()}, ", + f"got {text.keys(), wavs.keys(), symbolic.keys()}" + ) + + for attribute, batch in chain(text.items(), wavs.items(), symbolic.items()): + output[attribute] = self.conditioners[attribute].tokenize(batch) + return output + + def _collate_symbolic(self, samples: tp.List[ConditioningAttributes], + conditioner_keys: tp.Set) -> tp.Dict[str, SymbolicCondition]: + output = {} + + # collate if symbolic cond exists + if any(x in conditioner_keys for x in JascoCondConst.SYM.value): + + for s in samples: + # hydrate with null chord if chords not exist - for inference support + if (s.symbolic == {} or + s.symbolic[JascoCondConst.CRD.value].frame_chords is None or + s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] <= 1): # type: ignore + # no chords conditioning - fill with null chord token + s.symbolic[JascoCondConst.CRD.value] = SymbolicCondition( + frame_chords=torch.ones(self.sequence_len, dtype=torch.int32) * self.null_chord) + + if (s.symbolic == {} or + s.symbolic[JascoCondConst.MLD.value].melody is None or + s.symbolic[JascoCondConst.MLD.value].melody.shape[-1] <= 1): # type: ignore + # no chords conditioning - fill with null chord token + s.symbolic[JascoCondConst.MLD.value] = SymbolicCondition( + melody=torch.zeros((self.melody_dim, self.sequence_len))) + + if JascoCondConst.CRD.value in conditioner_keys: + # pad to max + max_seq_len = max( + [s.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1] for s in samples]) # type: ignore + padded_chords = [ + torch.cat((x.symbolic[JascoCondConst.CRD.value].frame_chords, # type: ignore + torch.ones(max_seq_len - + x.symbolic[JascoCondConst.CRD.value].frame_chords.shape[-1], # type: ignore + dtype=torch.int32) * self.null_chord)) + for x in samples + ] + output[JascoCondConst.CRD.value] = SymbolicCondition(frame_chords=torch.stack(padded_chords)) + if JascoCondConst.MLD.value in conditioner_keys: + melodies = torch.stack([x.symbolic[JascoCondConst.MLD.value].melody for x in samples]) # type: ignore + output[JascoCondConst.MLD.value] = SymbolicCondition(melody=melodies) + return output diff --git a/audiocraft/modules/unet_transformer.py b/audiocraft/modules/unet_transformer.py new file mode 100644 index 00000000..53fe1f85 --- /dev/null +++ b/audiocraft/modules/unet_transformer.py @@ -0,0 +1,67 @@ +import torch +import typing as tp +from .transformer import StreamingTransformer, create_sin_embedding + + +class UnetTransformer(StreamingTransformer): + """U-net Transformer for processing sequences with optional skip connections. + This transformer architecture incorporates U-net style skip connections + between layers, which can be optionally enabled. It inherits from a + StreamingTransformer. + + Args: + d_model (int): Dimension of the model, typically the number of expected features in the input. + num_layers (int): Total number of layers in the transformer. + skip_connections (bool, optional): Flag to determine whether skip connections should be used. + Defaults to False. + layer_dropout_p (float, Optional): if given, defined bernoulli prob. to drop a skip connection (in training). + **kwargs: Additional keyword arguments inherited from `nn.StreamingTransformer`. 
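        Note: when skip_connections is enabled, each layer in the first half of the stack pushes its output onto a stack; each layer in the second half pops one entry, concatenates it with its own input and linearly projects the result back to d_model. During training a pushed activation is replaced by zeros with probability layer_dropout_p, which effectively drops that skip path for the step.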
+ """ + def __init__(self, d_model: int, num_layers: int, skip_connections: bool = False, + layer_dropout_p: tp.Optional[float] = None, **kwargs): + super().__init__(d_model=d_model, + num_layers=num_layers, + **kwargs) + self.skip_connect = skip_connections + if self.skip_connect: + self.skip_projections = torch.nn.ModuleList([torch.nn.Linear(d_model * 2, d_model) + for _ in range(num_layers // 2)]) + self.num_layers = num_layers + self.layer_drop_p = max(min(layer_dropout_p, 1.), 0.) if layer_dropout_p is not None else 0.0 + + def forward(self, x: torch.Tensor, *args, **kwargs): + B, T, C = x.shape + + if 'offsets' in self._streaming_state: + offsets = self._streaming_state['offsets'] + else: + offsets = torch.zeros(B, dtype=torch.long, device=x.device) + + if self.positional_embedding in ['sin', 'sin_rope']: + positions = torch.arange(T, device=x.device).view(1, -1, 1) + positions = positions + offsets.view(-1, 1, 1) + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype) + x = x + self.positional_scale * pos_emb + + skip_connections: tp.List[torch.Tensor] = [] + + for i, layer in enumerate(self.layers): + if self.skip_connect and i >= self.num_layers // 2: + + # in the second half of the layers, add residual connection + # and linearly project the concatenated features back to d_model + x = torch.cat([x, skip_connections.pop()], dim=-1) + x = self.skip_projections[i % len(self.skip_projections)](x) + + x = self._apply_layer(layer, x, *args, **kwargs) + + if self.skip_connect and i < self.num_layers // 2: + if self.training and torch.rand(1,) < self.layer_drop_p: # drop skip + skip_connections.append(torch.zeros_like(x)) + else: + skip_connections.append(x) + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return x diff --git a/audiocraft/solvers/builders.py b/audiocraft/solvers/builders.py index bf18f2d6..e39993a8 100644 --- a/audiocraft/solvers/builders.py +++ b/audiocraft/solvers/builders.py @@ -38,6 +38,7 @@ class DatasetType(Enum): AUDIO = "audio" MUSIC = "music" SOUND = "sound" + JASCO = "jasco" def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: @@ -48,6 +49,7 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: from .diffusion import DiffusionSolver from .magnet import MagnetSolver, AudioMagnetSolver from .watermark import WatermarkSolver + from .jasco import JascoSolver klass = { 'compression': CompressionSolver, 'musicgen': MusicGenSolver, @@ -58,6 +60,7 @@ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver: 'diffusion': DiffusionSolver, 'sound_lm': AudioGenSolver, # backward compatibility 'watermarking': WatermarkSolver, + 'jasco': JascoSolver, }[cfg.solver] return klass(cfg) # type: ignore @@ -355,6 +358,8 @@ def get_audio_datasets(cfg: omegaconf.DictConfig, dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs) elif dataset_type == DatasetType.AUDIO: dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs) + elif dataset_type == DatasetType.JASCO: + dataset = data.jasco_dataset.JascoDataset.from_meta(path, return_info=return_info, **kwargs) else: raise ValueError(f"Dataset type is unsupported: {dataset_type}") diff --git a/audiocraft/solvers/jasco.py b/audiocraft/solvers/jasco.py new file mode 100644 index 00000000..8b6d36e4 --- /dev/null +++ b/audiocraft/solvers/jasco.py @@ -0,0 +1,287 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from omegaconf import DictConfig +from . import builders, musicgen +from .compression import CompressionSolver +from .. import models +from ..modules.conditioners import JascoCondConst, SegmentWithAttributes +import torch +import typing as tp +import flashy +import time +import math + + +class JascoSolver(musicgen.MusicGenSolver): + """Solver for JASCO - Joint Audio and Symbolic Conditioning for Temporally Controlled Text-to-Music Generation + https://arxiv.org/abs/2406.10970. + """ + DATASET_TYPE: builders.DatasetType = builders.DatasetType.JASCO + + def __init__(self, cfg: DictConfig): + super().__init__(cfg) + + # initialize generation parameters by config + self.generation_params = { + 'cfg_coef_all': self.cfg.generate.lm.cfg_coef_all, + 'cfg_coef_txt': self.cfg.generate.lm.cfg_coef_txt + } + + self.latent_mean = cfg.compression_model_latent_mean + self.latent_std = cfg.compression_model_latent_std + self.mse = torch.nn.MSELoss(reduction='none') + self._best_metric_name = 'loss' + + def build_model(self) -> None: + """Instantiate model and optimization.""" + assert self.cfg.efficient_attention_backend == "xformers", "JASCO v1 models support only xformers backend." + + self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( + self.cfg, self.cfg.compression_model_checkpoint, device=self.device) + assert self.compression_model.sample_rate == self.cfg.sample_rate, ( + f"Compression model sample rate is {self.compression_model.sample_rate} but " + f"Solver sample rate is {self.cfg.sample_rate}." + ) + # instantiate JASCO model + self.model: models.FlowMatchingModel = models.builders.get_jasco_model(self.cfg, + self.compression_model).to(self.device) + # initialize optimization + self.initialize_optimization() + + def _get_latents(self, audio): + with torch.no_grad(): + latents = self.compression_model.model.encoder(audio) + return latents.permute(0, 2, 1) # [B, D, T] -> [B, T, D] + + def _prepare_latents_and_attributes( + self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]: + """Prepare input batchs for language model training. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T] + and corresponding metadata as SegmentWithAttributes (with B items). + Returns: + Condition tensors (dict[str, any]): Preprocessed condition attributes. + Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], + with B the batch size, K the number of codebooks, T_s the token timesteps. + Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. 
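            Note: unlike the parent MusicGen solver, the second returned tensor here holds continuous (not yet normalized) EnCodec latents of shape [B, T, D] rather than discrete tokens, and the returned padding mask shares that [B, T, D] shape.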
+ """ + audio, infos = batch + audio = audio.to(self.device) + assert audio.size(0) == len(infos), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(infos)})" + ) + + latents = self._get_latents(audio) + + # prepare attributes + if JascoCondConst.CRD.value in self.cfg.conditioners: + null_chord_idx = self.cfg.conditioners.chords.chords_emb.card + else: + null_chord_idx = -1 + attributes = [info.to_condition_attributes() for info in infos] + if self.model.cfg_dropout is not None: + attributes = self.model.cfg_dropout(samples=attributes, + cond_types=["wav", "text", "symbolic"], + null_chord_idx=null_chord_idx) + attributes = self.model.att_dropout(attributes) + tokenized = self.model.condition_provider.tokenize(attributes) + + with self.autocast: + condition_tensors = self.model.condition_provider(tokenized) + + # create a padding mask to hold valid vs invalid positions + padding_mask = torch.ones_like(latents, dtype=torch.bool, device=latents.device) + + return condition_tensors, latents, padding_mask + + def _normalized_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Normalize latents.""" + return (latents - self.latent_mean) / self.latent_std + + def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Unnormalize latents.""" + return (latents * self.latent_std) + self.latent_mean + + def _z(self, z_0: torch.Tensor, z_1: torch.Tensor, t: torch.Tensor, sigma_min: float = 1e-5) -> torch.Tensor: + """Interpolate data and prior.""" + return (1 - (1 - sigma_min) * t) * z_0 + t * z_1 + + def _vector_field(self, z_0: torch.Tensor, z_1: torch.Tensor, sigma_min: float = 1e-5) -> torch.Tensor: + """Compute the GT vector field. + sigma_min is a small value to avoid numerical instabilities.""" + return z_1 - (1 - sigma_min) * z_0 + + def _compute_loss(self, t: torch.Tensor, v_theta: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Compute the loss.""" + loss_func = self.cfg.get('loss_func', 'increasing') + if loss_func == 'uniform': + scales = 1 + elif loss_func == 'increasing': + scales = 1 + t # type: ignore + elif loss_func == 'decreasing': + scales = 2 - t # type: ignore + else: + raise ValueError('unsupported loss_func was passed in config') + return (scales * self.mse(v_theta, v)).mean() + + def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: + """Perform one training or valid step on a given batch.""" + + condition_tensors, latents, padding_mask = self._prepare_latents_and_attributes(batch) + + self.deadlock_detect.update('tokens_and_conditions') + + B, T, D = latents.shape + device = self.device + + # normalize latents + z_1 = self._normalized_latents(latents) + + # sample the N(0,1) prior + z_0 = torch.randn(B, T, D, device=device) + + # random time parameter, between 0 to 1 + t = torch.rand((B, 1, 1), device=device) + + # interpolate data and prior + z = self._z(z_0, z_1, t) + + # compute the GT vector field + v = self._vector_field(z_0, z_1) + + with self.autocast: + v_theta = self.model(latents=z, + t=t, + conditions=[], + condition_tensors=condition_tensors) + + loss = self._compute_loss(t, v_theta, v) + unscaled_loss = loss.clone() + + self.deadlock_detect.update('loss') + + if self.is_training: + metrics['lr'] = self.optimizer.param_groups[0]['lr'] + if self.scaler is not None: + loss = self.scaler.scale(loss) + self.deadlock_detect.update('scale') + if self.cfg.fsdp.use: + loss.backward() + 
flashy.distrib.average_tensors(self.model.buffers()) + elif self.cfg.optim.eager_sync: + with flashy.distrib.eager_sync_model(self.model): + loss.backward() + else: + # this should always be slower but can be useful + # for weird use cases like multiple backwards. + loss.backward() + flashy.distrib.sync_model(self.model) + self.deadlock_detect.update('backward') + + if self.scaler is not None: + self.scaler.unscale_(self.optimizer) + if self.cfg.optim.max_norm: + if self.cfg.fsdp.use: + metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore + else: + metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.optim.max_norm + ) + if self.scaler is None: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + if self.lr_scheduler: + self.lr_scheduler.step() + self.optimizer.zero_grad() + self.deadlock_detect.update('optim') + if self.scaler is not None: + scale = self.scaler.get_scale() + metrics['grad_scale'] = scale + if not loss.isfinite().all(): + raise RuntimeError("Model probably diverged.") + + metrics['loss'] = unscaled_loss + + return metrics + + def _decode_latents(self, latents): + return self.compression_model.model.decoder(latents.permute(0, 2, 1)) + + @torch.no_grad() + def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], + gen_duration: float, prompt_duration: tp.Optional[float] = None, + remove_text_conditioning: bool = False, + **generation_params) -> dict: + """Run generate step on a batch of optional audio tensor and corresponding attributes. + + Args: + batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): + use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch. + gen_duration (float): Target audio duration for the generation. + prompt_duration (float, optional): Duration for the audio prompt to use for continuation. + remove_text_conditioning (bool, optional): Whether to remove the prompt from the generated audio. + generation_params: Additional generation parameters. + Returns: + gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation + and the prompt along with additional information. + """ + bench_start = time.time() + audio, meta = batch + assert audio.size(0) == len(meta), ( + f"Mismatch between number of items in audio batch ({audio.size(0)})", + f" and in metadata ({len(meta)})" + ) + # prepare attributes + attributes = [x.to_condition_attributes() for x in meta] + + # prepare audio prompt + if prompt_duration is None: + prompt_audio = None + else: + assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration" + prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate) + prompt_audio = audio[..., :prompt_audio_frames] + + # get audio tokens from compression model + if prompt_audio is None or prompt_audio.nelement() == 0: + num_samples = len(attributes) + prompt_tokens = None + else: + num_samples = None + prompt_audio = prompt_audio.to(self.device) + prompt_tokens, scale = self.compression_model.encode(prompt_audio) + assert scale is None, "Compression model in MusicGen should not require rescaling." 
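        # (editor's note, illustrative sketch only, not part of this patch) generation inverts the
        # flow-matching objective trained in run_step above: with sigma_min ~= 0 the interpolant is
        # z_t = (1 - t) * z_0 + t * z_1, where z_0 ~ N(0, I) and z_1 are the normalized latents, so the
        # regressed target field is v = z_1 - z_0 = dz_t/dt. A bare Euler integration of a learned field
        # v_theta would look roughly like the lines below; the FlowMatchingModel.generate call used here
        # additionally handles conditioning and the double classifier-free guidance (cfg_coef_all / cfg_coef_txt):
        #     z = torch.randn(B, T, D, device=device)
        #     dt = 1.0 / num_steps
        #     for k in range(num_steps):
        #         t = torch.full((B, 1, 1), k * dt, device=device)
        #         z = z + dt * v_theta(latents=z, t=t, conditions=[], condition_tensors=condition_tensors)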
+ + # generate by sampling from the LM + with self.autocast: + total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate) + gen_latents = self.model.generate( + prompt_tokens, attributes, max_gen_len=total_gen_len, + num_samples=num_samples, **self.generation_params) + + # generate audio from latents + assert gen_latents.dim() == 3 # [B, T, D] + + # unnormalize latents + gen_latents = self._unnormalized_latents(gen_latents) + gen_audio = self._decode_latents(gen_latents) + + bench_end = time.time() + gen_outputs = { + 'rtf': (bench_end - bench_start) / gen_duration, + 'ref_audio': audio, + 'gen_audio': gen_audio, + 'gen_tokens': gen_latents, + 'prompt_audio': prompt_audio, + 'prompt_tokens': prompt_tokens, + } + return gen_outputs diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py index 8a5c2c03..8b48d00b 100644 --- a/audiocraft/solvers/musicgen.py +++ b/audiocraft/solvers/musicgen.py @@ -111,6 +111,32 @@ def get_formatter(self, stage_name: str) -> flashy.Formatter: def best_metric_name(self) -> tp.Optional[str]: return self._best_metric_name + def initialize_optimization(self) -> None: + if self.cfg.fsdp.use: + assert not self.cfg.autocast, "Cannot use autocast with fsdp" + self.model = self.wrap_with_fsdp(self.model) + self.register_ema('model') + # initialize optimization + self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) + self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) + self.register_stateful('model', 'optimizer', 'lr_scheduler') + self.register_best_state('model') + self.autocast_dtype = { + 'float16': torch.float16, 'bfloat16': torch.bfloat16 + }[self.cfg.autocast_dtype] + self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None + if self.cfg.fsdp.use: + need_scaler = self.cfg.fsdp.param_dtype == 'float16' + else: + need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 + if need_scaler: + if self.cfg.fsdp.use: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + self.scaler = ShardedGradScaler() # type: ignore + else: + self.scaler = torch.cuda.amp.GradScaler() + self.register_stateful('scaler') + def build_model(self) -> None: """Instantiate models and optimizer.""" # we can potentially not use all quantizers with which the EnCodec model was trained @@ -136,31 +162,11 @@ def build_model(self) -> None: self.compression_model.num_codebooks, self.compression_model.cardinality, self.compression_model.frame_rate) # instantiate LM model - self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) - if self.cfg.fsdp.use: - assert not self.cfg.autocast, "Cannot use autocast with fsdp" - self.model = self.wrap_with_fsdp(self.model) - self.register_ema('model') + self.model: tp.Union[models.LMModel, models.FlowMatchingModel] = models.builders.get_lm_model( + self.cfg).to(self.device) + # initialize optimization - self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) - self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) - self.register_stateful('model', 'optimizer', 'lr_scheduler') - self.register_best_state('model') - self.autocast_dtype = { - 'float16': torch.float16, 'bfloat16': torch.bfloat16 - }[self.cfg.autocast_dtype] - self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None - if self.cfg.fsdp.use: - need_scaler = self.cfg.fsdp.param_dtype == 'float16' - else: - 
need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 - if need_scaler: - if self.cfg.fsdp.use: - from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler - self.scaler = ShardedGradScaler() # type: ignore - else: - self.scaler = torch.cuda.amp.GradScaler() - self.register_stateful('scaler') + self.initialize_optimization() def build_dataloaders(self) -> None: """Instantiate audio dataloaders for each stage.""" @@ -244,6 +250,12 @@ def _compute_cross_entropy( ce = ce / K return ce, ce_per_codebook + def _get_audio_tokens(self, audio: torch.Tensor): + with torch.no_grad(): + audio_tokens, scale = self.compression_model.encode(audio) + assert scale is None, "Scaled compression model not supported with LM." + return audio_tokens + def _prepare_tokens_and_attributes( self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], check_synchronization_points: bool = False @@ -310,9 +322,7 @@ def _prepare_tokens_and_attributes( torch.cuda.set_sync_debug_mode("warn") if audio_tokens is None: - with torch.no_grad(): - audio_tokens, scale = self.compression_model.encode(audio) - assert scale is None, "Scaled compression model not supported with LM." + audio_tokens = self._get_audio_tokens(audio) with self.autocast: condition_tensors = self.model.condition_provider(tokenized) @@ -679,7 +689,7 @@ def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: gen_outputs = self.run_generate_step( batch, gen_duration=target_duration, - remove_text_conditioning=self.cfg.evaluate.remove_text_conditioning + remove_text_conditioning=self.cfg.evaluate.get('remove_text_conditioning', False) ) y_pred = gen_outputs['gen_audio'].detach() y_pred = y_pred[..., :audio.shape[-1]] diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py index 2c5799f8..bc9d9f39 100644 --- a/audiocraft/utils/utils.py +++ b/audiocraft/utils/utils.py @@ -12,7 +12,6 @@ import logging from pathlib import Path import typing as tp - import flashy import flashy.distrib import omegaconf @@ -296,3 +295,32 @@ def load_clap_state_dict(clap_model, path: tp.Union[str, Path]): pkg = load_state_dict(path) pkg.pop('text_branch.embeddings.position_ids', None) clap_model.model.load_state_dict(pkg) + + +def construct_frame_chords( + min_timestamp: int, + chord_changes: tp.List[tp.Tuple[float, str]], + mapping_dict: tp.Dict, + prev_chord: str, + frame_rate: float, + segment_duration: float, + ) -> tp.List[str]: + """ Translate symbolic chords [(start_time, tuples),...] 
into a frame-level int sequence""" + + frames = [ + frame / frame_rate + for frame in range( + min_timestamp, int(min_timestamp + segment_duration * frame_rate) + ) + ] + + frame_chords = [] + current_chord = prev_chord + + for frame in frames: + while chord_changes and frame >= chord_changes[0][0]: + current_chord = chord_changes.pop(0)[1] + current_chord = 'N' if current_chord in {None, ''} else current_chord + frame_chords.append(mapping_dict[current_chord]) + + return frame_chords diff --git a/config/conditioner/chords2music.yaml b/config/conditioner/chords2music.yaml new file mode 100644 index 00000000..33c6b564 --- /dev/null +++ b/config/conditioner/chords2music.yaml @@ -0,0 +1,38 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 # dropout of all conditions + inference_coef: 3.0 + +attribute_dropout: + symbolic: + chords: 0.3 # independent dropout of chords + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + ignore: [chords] + input_interpolate: [] + +conditioners: + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + chords: + model: chords_emb + chords_emb: + card: 194 # Chordino + out_dim: 16 + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/drums2music.yaml b/config/conditioner/drums2music.yaml new file mode 100644 index 00000000..dfeea215 --- /dev/null +++ b/config/conditioner/drums2music.yaml @@ -0,0 +1,42 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 # dropout of all conditions + inference_coef: 3.0 + +attribute_dropout: + text: {} + wav: + self_wav: 0.3 # independent dropout of drums + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + ignore: [self_wav] + input_interpolate: [] + +conditioners: + self_wav: + model: drum_latents + drum_latents: + sample_rate: ${sample_rate} + out_dim: 2 + blurring_factor: 3 + cache_path: null + + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/conditioner/jasco_chords_drums.yaml b/config/conditioner/jasco_chords_drums.yaml new file mode 100644 index 00000000..4417361c --- /dev/null +++ b/config/conditioner/jasco_chords_drums.yaml @@ -0,0 +1,50 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.3 # dropout of all conditions + inference_coef: 3.0 + +attribute_dropout: + text: {} + symbolic: + chords: 0.3 # independent dropout of chords + wav: + self_wav: 0.3 # independent dropout of drums + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + ignore: [chords, self_wav] + input_interpolate: [] + +conditioners: + self_wav: + model: drum_latents + drum_latents: + sample_rate: ${sample_rate} + out_dim: 2 + blurring_factor: 3 + cache_path: null + + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + chords: + model: chords_emb + chords_emb: + card: 194 # Chordino + out_dim: 16 + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 + diff --git a/config/conditioner/jasco_chords_drums_melody.yaml b/config/conditioner/jasco_chords_drums_melody.yaml new file mode 100644 index 
00000000..08a2157a --- /dev/null +++ b/config/conditioner/jasco_chords_drums_melody.yaml @@ -0,0 +1,60 @@ +# @package __global__ + +classifier_free_guidance: + training_dropout: 0.2 # dropout of all conditions + inference_coef: 3.0 + +attribute_dropout: + text: + description: 0.0 + symbolic: + chords: 0.5 # independent dropout of chords + melody: 0.5 # independent dropout of melody + wav: + self_wav: 0.5 # independent dropout of drums + + +fuser: + cross_attention_pos_emb: false + cross_attention_pos_emb_scale: 1 + sum: [] + prepend: [] + cross: [description] + ignore: [chords, self_wav, melody] + input_interpolate: [] + +conditioners: + self_wav: + model: drum_latents + drum_latents: + sample_rate: ${sample_rate} + out_dim: 2 + blurring_factor: 3 + cache_path: ??? + read_only_cache: true + + description: + model: t5 + t5: + name: t5-base + finetune: false + word_dropout: 0.3 + normalize_text: false + + chords: + model: chords_emb + chords_emb: + card: 194 # Chordino + out_dim: 16 + + melody: + model: melody + melody: + card: 53 # Preprocessed salience dim + out_dim: 16 + +dataset: + train: + merge_text_p: 0.25 + drop_desc_p: 0.5 + drop_other_p: 0.5 diff --git a/config/solver/jasco/chords.yaml b/config/solver/jasco/chords.yaml new file mode 100644 index 00000000..d9c671c1 --- /dev/null +++ b/config/solver/jasco/chords.yaml @@ -0,0 +1,81 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/example + - override /conditioner: chords2music + - _self_ + +lm_model: flow_matching +solver: jasco + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz +# precomputed mean std accross training data sub-partition +compression_model_latent_std: 4.0102 +compression_model_latent_mean: -0.0074 +compression_model_framerate: 50 +compression_model_latent_dim: 128 + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + segment_duration: 10 + batch_size: 320 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + chords_card: ${conditioners.chords.chords_emb.card} + compression_model_framerate: ${compression_model_framerate} + +optim: + epochs: 500 + optimizer: adamw + lr: 1e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +transformer_lm: + causal: false + skip_connections: false + flow_dim: ${compression_model_latent_dim} + chords_dim: ${conditioners.chords.chords_emb.out_dim} + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + cfg_coef_all: 3.0 + cfg_coef_txt: 1.0 + prompted_samples: false + samples: + prompted: false + unprompted: true diff --git a/config/solver/jasco/chords_drums.yaml b/config/solver/jasco/chords_drums.yaml new file mode 100644 index 00000000..50ac8097 --- /dev/null +++ b/config/solver/jasco/chords_drums.yaml @@ -0,0 +1,88 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /conditioner: jasco_chords_drums + - override /dset: audio/default + - _self_ + +lm_model: flow_matching +solver: jasco + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz +# precomputed mean std accross training data sub-partition +compression_model_latent_std: 4.0102 +compression_model_latent_mean: -0.0074 +compression_model_framerate: 50 +compression_model_latent_dim: 128 + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + segment_duration: 10 + batch_size: 336 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + compression_model_framerate: ${compression_model_framerate} + +optim: + epochs: 500 + optimizer: adamw + lr: 1e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +transformer_lm: + causal: false + skip_connections: false + flow_dim: ${compression_model_latent_dim} + chords_dim: ${conditioners.chords.chords_emb.out_dim} + drums_dim: ${conditioners.self_wav.drum_latents.out_dim} + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + cfg_coef_all: 5.0 + cfg_coef_txt: 0.0 + prompted_samples: false + samples: + prompted: false + unprompted: true + +conditioners: + self_wav: + drum_latents: + compression_model_latent_dim: ${compression_model_latent_dim} + compression_model_framerate: ${compression_model_framerate} + segment_duration: ${dataset.segment_duration} diff --git a/config/solver/jasco/chords_drums_melody.yaml b/config/solver/jasco/chords_drums_melody.yaml new file mode 100644 index 00000000..c3a1a135 --- /dev/null +++ b/config/solver/jasco/chords_drums_melody.yaml @@ -0,0 +1,97 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - override /conditioner: jasco_chords_drums_melody + - _self_ + +lm_model: flow_matching +solver: jasco + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. +# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz +# precomputed mean std accross training data sub-partition +compression_model_latent_std: 4.0102 +compression_model_latent_mean: -0.0074 +compression_model_framerate: 50 +compression_model_latent_dim: 128 + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + segment_duration: 10 + batch_size: 336 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + compression_model_framerate: ${compression_model_framerate} + melody_kwargs: + chroma_root: ??? 
# path to parsed chroma files + segment_duration: ${dataset.segment_duration} + melody_fr: 86 + latent_fr: ${compression_model_framerate} + melody_salience_dim: 53 + override_cache: false + do_argmax: true + +optim: + epochs: 500 + optimizer: adamw + lr: 1e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +transformer_lm: + causal: false + skip_connections: false + flow_dim: ${compression_model_latent_dim} + chords_dim: ${conditioners.chords.chords_emb.out_dim} + drums_dim: ${conditioners.self_wav.drum_latents.out_dim} + melody_dim: ${conditioners.melody.melody.out_dim} + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + cfg_coef_all: 3.0 + cfg_coef_txt: 1.0 + prompted_samples: false + samples: + prompted: false + unprompted: true + +conditioners: + self_wav: + drum_latents: + compression_model_latent_dim: ${compression_model_latent_dim} + compression_model_framerate: ${compression_model_framerate} + segment_duration: ${dataset.segment_duration} diff --git a/config/solver/jasco/drums.yaml b/config/solver/jasco/drums.yaml new file mode 100644 index 00000000..bcaf1ebf --- /dev/null +++ b/config/solver/jasco/drums.yaml @@ -0,0 +1,87 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - override /conditioner: drums2music + - _self_ + +lm_model: flow_matching +solver: jasco + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz +# precomputed mean std accross training data sub-partition +compression_model_latent_std: 4.0102 +compression_model_latent_mean: -0.0074 +compression_model_framerate: 50 +compression_model_latent_dim: 128 + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + segment_duration: 10 + batch_size: 336 + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + compression_model_framerate: ${compression_model_framerate} + +optim: + epochs: 500 + optimizer: adamw + lr: 1e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +transformer_lm: + causal: false + skip_connections: false + flow_dim: ${compression_model_latent_dim} + drums_dim: ${conditioners.self_wav.drum_latents.out_dim} + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + cfg_coef_all: 3.0 + cfg_coef_txt: 1.0 + prompted_samples: false + samples: + prompted: false + unprompted: true + +conditioners: + self_wav: + drum_latents: + compression_model_latent_dim: ${compression_model_latent_dim} + compression_model_framerate: ${compression_model_framerate} + segment_duration: ${dataset.segment_duration} diff --git a/config/solver/jasco/jasco_32khz_base.yaml b/config/solver/jasco/jasco_32khz_base.yaml new file mode 100644 index 00000000..5b0aa942 --- /dev/null +++ b/config/solver/jasco/jasco_32khz_base.yaml @@ -0,0 +1,78 @@ +# @package __global__ + +# This is the training loop solver +# for the base MusicGen model (text-to-music) +# on monophonic audio sampled at 32 kHz +defaults: + - musicgen/default + - /model: lm/musicgen_lm + - override /dset: audio/default + - _self_ + +lm_model: flow_matching +solver: jasco + +autocast: true +autocast_dtype: float16 + +# EnCodec large trained on mono-channel music audio sampled at 32khz +# with a total stride of 640 leading to 50 frames/s. 
+# rvq.n_q=4, rvq.bins=2048, no quantization dropout +# (transformer_lm card and n_q must be compatible) +compression_model_checkpoint: //pretrained/facebook/encodec_32khz +# precomputed mean std accross training data sub-partition +compression_model_latent_std: 4.0102 +compression_model_latent_mean: -0.0074 +compression_model_framerate: 50 +compression_model_latent_dim: 128 + +efficient_attention_backend: xformers # restricted attention implementation supports only xformers at the moment + +channels: 1 +sample_rate: 32000 + +deadlock: + use: true # deadlock detection + +dataset: + segment_duration: 10 + batch_size: 320 # 32 GPUs + sample_on_weight: false # Uniform sampling all the way + sample_on_duration: false # Uniform sampling all the way + +optim: + epochs: 500 + optimizer: adamw + lr: 1e-4 + ema: + use: true + updates: 10 + device: cuda + +logging: + log_tensorboard: true + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +transformer_lm: + causal: false + skip_connections: false + flow_dim: ${compression_model_latent_dim} + +generate: + lm: + max_prompt_len: null + max_gen_len: null + remove_prompts: false + cfg_coef_all: 3.0 + cfg_coef_txt: 0.0 + + prompted_samples: false + samples: + prompted: false + unprompted: true diff --git a/demos/jasco_app.py b/demos/jasco_app.py new file mode 100644 index 00000000..18e0009f --- /dev/null +++ b/demos/jasco_app.py @@ -0,0 +1,364 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under thmage license found in the +# LICENSE file in the root directory of this source tree. +import argparse +from concurrent.futures import ProcessPoolExecutor +import logging +import os +from pathlib import Path +import subprocess as sp +import sys +from tempfile import NamedTemporaryFile +import time +import typing as tp +import torch +import gradio as gr # type: ignore +from audiocraft.data.audio_utils import f32_pcm, normalize_audio +from audiocraft.data.audio import audio_write +from audiocraft.models import JASCO +# flake8: noqa + +MODEL = None # Last used model +SPACE_ID = os.environ.get('SPACE_ID', '') +MAX_BATCH_SIZE = 12 +INTERRUPTING = False +MBD = None +# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform +_old_call = sp.call + + +def _call_nostderr(*args, **kwargs): + # Avoid ffmpeg vomiting on the logs. + kwargs['stderr'] = sp.DEVNULL + kwargs['stdout'] = sp.DEVNULL + _old_call(*args, **kwargs) + + +sp.call = _call_nostderr +# Preallocating the pool of processes. 
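# (editor's note) one caveat in the _call_nostderr wrapper above: it silences ffmpeg output but does not return _old_call's result, so anything that checks sp.call's exit status would now receive None.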
+pool = ProcessPoolExecutor(4) +pool.__enter__() + + +def interrupt(): + global INTERRUPTING + INTERRUPTING = True + + +class FileCleaner: + def __init__(self, file_lifetime: float = 3600): + self.file_lifetime = file_lifetime + self.files = [] # type: ignore + + def add(self, path: tp.Union[str, Path]): + self._cleanup() + self.files.append((time.time(), Path(path))) + + def _cleanup(self): + now = time.time() + for time_added, path in list(self.files): + if now - time_added > self.file_lifetime: + if path.exists(): + path.unlink() + self.files.pop(0) + else: + break + + +file_cleaner = FileCleaner() + + +def chords_string_to_list(chords: str): + if chords == '': + return [] + + # clean white spaces or [ ] chars + chords = chords.replace('[', '') + chords = chords.replace(']', '') + chords = chords.replace(' ', '') + chrd_times = [x.split(',') for x in chords[1:-1].split('),(')] + return [(x[0], float(x[1])) for x in chrd_times] + + +def load_model(version='facebook/jasco-chords-drums-400M'): + global MODEL + print("Loading model", version) + if MODEL is None or MODEL.name != version: + MODEL = None # in case loading would crash + MODEL = JASCO.get_pretrained(version) + + +def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=False, gradio_progress=None, **gen_kwargs): + MODEL.set_generation_params(**gen_kwargs) + be = time.time() + + # preprocess chords: str to list of tuples + chords = chords_string_to_list(chords) + + if melody_matrix is not None: + melody_matrix = torch.load(melody_matrix.name, weights_only=True) + if len(melody_matrix.shape) != 2: + raise gr.Error(f"Melody matrix should be a torch tensor of shape [n_melody_bins, T]; got: {melody_matrix.shape}") + if melody_matrix.shape[0] > melody_matrix.shape[1]: + melody_matrix = melody_matrix.permute(1, 0) + + # preprocess drums + if drum_prompt is None: + preprocessed_drums_wav = None + drums_sr = 32000 + else: + # gradio loads audio in int PCM 16-bit, we need to convert it to float32 + drums_sr, drums = drum_prompt[0], f32_pcm(torch.from_numpy(drum_prompt[1])).t() + if drums.dim() == 1: + drums = drums[None] + + drums = normalize_audio(drums, strategy="loudness", loudness_headroom_db=16, sample_rate=drums_sr) + preprocessed_drums_wav = drums + try: + outputs = MODEL.generate_music(descriptions=texts, chords=chords, + drums_wav=preprocessed_drums_wav, + melody_salience_matrix=melody_matrix, + drums_sample_rate=drums_sr, progress=progress) + except RuntimeError as e: + raise gr.Error("Error while generating " + e.args[0]) + outputs = outputs.detach().cpu().float() + out_wavs = [] + for output in outputs: + with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: + audio_write( + file.name, output, MODEL.sample_rate, strategy="loudness", + loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) + out_wavs.append(file.name) + file_cleaner.add(file.name) + print("batch finished", len(texts), time.time() - be) + print("Tempfiles currently stored: ", len(file_cleaner.files)) + return out_wavs + + +def predict_full(model, + text, chords_sym, melody_file, + drums_file, drums_mic, drum_input_src, + cfg_coef_all, cfg_coef_txt, + ode_rtol, ode_atol, + ode_solver, ode_steps, + progress=gr.Progress()): + global INTERRUPTING + INTERRUPTING = False + progress(0, desc="Loading model...") + load_model(model) + + max_generated = 0 + + def _progress(generated, to_generate): + nonlocal max_generated + max_generated = max(generated, max_generated) + progress((min(max_generated, to_generate), to_generate)) + 
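            # raising here aborts an in-flight generation once the "Interrupt" button sets the global flag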
if INTERRUPTING: + raise gr.Error("Interrupted.") + MODEL.set_custom_progress_callback(_progress) + + drums = drums_mic if drum_input_src == "mic" else drums_file + wavs = _do_predictions( + texts=[text] * 2, # we generate two audio outputs for each input prompt + chords=chords_sym, + drum_prompt=drums, + melody_matrix=melody_file, + progress=True, + gradio_progress=progress, + cfg_coef_all=cfg_coef_all, + cfg_coef_txt=cfg_coef_txt, + ode_rtol=ode_rtol, + ode_atol=ode_atol, + euler=ode_solver == 'euler', + euler_steps=ode_steps) + + return wavs + + +def ui_full(launch_kwargs): + with gr.Blocks() as interface: + gr.Markdown( + """ + # JASCO + This is your private demo for [JASCO](https://github.com/facebookresearch/audiocraft), + A text-to-music model, with temporal control over melodies, chords or beats. + + presented at: ["Joint Audio and Symbolic Conditioning for Temporally Controlled Text-to-Music Generation"] + (https://arxiv.org/abs/2406.10970) + """ + ) + # Submit | generated + with gr.Row(): + with gr.Column(): + submit = gr.Button("Submit") + # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. + _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) + + with gr.Column(): + audio_output_0 = gr.Audio(label="Generated Audio", type='filepath') + audio_output_1 = gr.Audio(label="Generated Audio", type='filepath') + + # TEXT | models + with gr.Row(): + with gr.Column(): + text = gr.Text(label="Input Text", + value="Strings, woodwind, orchestral, symphony.", + interactive=True) + with gr.Column(): + model = gr.Radio([ + 'facebook/jasco-chords-drums-400M', 'facebook/jasco-chords-drums-1B', + 'facebook/jasco-chords-drums-melody-400M', 'facebook/jasco-chords-drums-melody-1B', + ], + label="Model", value='facebook/jasco-chords-drums-melody-400M', interactive=True) + + # CHORDS + gr.Markdown("Chords conditions") + with gr.Row(): + chords_sym = gr.Text(label="Chord Progression", + value="(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)", + interactive=True) + + # DRUMS + gr.Markdown("Drums conditions") + with gr.Row(): + drum_input_src = gr.Radio(["file", "mic"], value="file", + label="Condition on drums (optional) File or Mic") + drums_file = gr.Audio(sources=["upload"], type="numpy", label="File", + interactive=True, elem_id="drums-input") + + drums_mic = gr.Audio(sources=["microphone"], type="numpy", label="Mic", + interactive=True, elem_id="drums-mic-input") + + # MELODY + gr.Markdown("Melody conditions") + with gr.Row(): + melody_file = gr.File(label="Melody File", interactive=True, elem_id="melody-file-input") + + # CFG params + gr.Markdown("Classifier-Free Guidance (CFG) Coefficients:") + with gr.Row(): + cfg_coef_all = gr.Number(label="ALL", value=1.25, step=0.25, interactive=True) + cfg_coef_txt = gr.Number(label="TEXT", value=2.5, step=0.25, interactive=True) + ode_tol = gr.Number(label="ODE solver tolerance (defines error approx stop threshold for dynammic solver)", + value=1e-4, step=1e-5, interactive=True) + ode_solver = gr.Radio([ + 'euler', 'dopri5' + ], + label="ODE Solver", value='euler', interactive=True) + ode_steps = gr.Number(label="Steps (for euler solver)", value=10, step=1, interactive=True) + + submit.click(fn=predict_full, + inputs=[model, + text, chords_sym, melody_file, + drums_file, drums_mic, drum_input_src, + cfg_coef_all, cfg_coef_txt, ode_tol, ode_tol, ode_solver, ode_steps], + outputs=[audio_output_0, audio_output_1]) + gr.Examples( + fn=predict_full, + examples=[ + [ + "80s pop with groovy synth bass and electric 
piano", + "(N, 0.0), (C, 0.32), (Dm7, 3.456), (Am, 4.608), (F, 8.32), (C, 9.216)", + "./assets/salience_2.th", + "./assets/salience_2.wav", + ], + [ + "Strings, woodwind, orchestral, symphony.", # text + "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)", # chords + None, # melody + None, # drums + ], + [ + "distortion guitars, heavy rock, catchy beat", + "", + None, + "./assets/sep_drums_1.mp3", + ], + [ + "hip hop beat with a catchy melody and a groovy bass line", + "", + None, + "./assets/CJ_Beatbox_Loop_05_90.wav", + ], + [ + "hip hop beat with a catchy melody and a groovy bass line", + "(C, 0.0), (D, 2.0), (F, 4.0), (Ab, 6.0), (Bb, 7.0), (C, 8.0)", + None, + "./assets/CJ_Beatbox_Loop_05_90.wav", + ], + + ], + inputs=[text, chords_sym, melody_file, drums_file], + outputs=[audio_output_0, audio_output_1] + ) + gr.Markdown( + """ + ### More details + + "JASCO" model will generate a 10 seconds of music based on textual descriptions together with + temporal controls such as chords and drum tracks. + These models were trained with descriptions from a stock music catalog. Descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + + We present 4 model variants: + 1. facebook/jasco-chords-drums-400M - 10s music generation conditioned on text, chords and drums,400M parameters. + 2. facebook/jasco-chords-drums-1B - 10s music generation conditioned on text, chords and drums, 1B parameters. + 3. facebook/jasco-chords-drums-melody-400M - 10s music generation conditioned on text, chords, drums and melody,400M parameters. + 4. facebook/jasco-chords-drums-melody-1B - 10s music generation conditioned on text, chords, drums and melody, 1B parameters. + + See https://github.com/facebookresearch/audiocraft/blob/main/docs/JASCO.md + for more details. + """ + ) + + interface.queue().launch(**launch_kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--listen', + type=str, + default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1', + help='IP to listen on for connections to Gradio', + ) + parser.add_argument( + '--username', type=str, default='', help='Username for authentication' + ) + parser.add_argument( + '--password', type=str, default='', help='Password for authentication' + ) + parser.add_argument( + '--server_port', + type=int, + default=0, + help='Port to run the server listener on', + ) + parser.add_argument( + '--inbrowser', action='store_true', help='Open in browser' + ) + parser.add_argument( + '--share', action='store_true', help='Share the gradio UI' + ) + + args = parser.parse_args() + + launch_kwargs = {} + launch_kwargs['server_name'] = args.listen + + if args.username and args.password: + launch_kwargs['auth'] = (args.username, args.password) + if args.server_port: + launch_kwargs['server_port'] = args.server_port + if args.inbrowser: + launch_kwargs['inbrowser'] = args.inbrowser + if args.share: + launch_kwargs['share'] = args.share + + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + + # Show the interface + ui_full(launch_kwargs) diff --git a/demos/jasco_demo.ipynb b/demos/jasco_demo.ipynb new file mode 100644 index 00000000..6f0afbd3 --- /dev/null +++ b/demos/jasco_demo.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# JASCO\n", + "Welcome to JASCO's demo jupyter notebook. 
\n", + "Here you will find a self-contained example of how to use JASCO for temporally controlled music generation.\n", + "\n", + "You can choose a model from the following selection:\n", + "1. facebook/jasco-chords-drums-400M - 10s music generation conditioned on text, chords and drums, 400M parameters\n", + "2. facebook/jasco-chords-drums-1B - 10s music generation conditioned on text, chords and drums, 1B parameters\n", + "3. facebook/jasco-chords-drums-melody-400M - 10s music generation conditioned on text, chords, drums and melody, 400M parameters\n", + "4. facebook/jasco-chords-drums-melody-1B - 10s music generation conditioned on text, chords, drums and melody, 1B parameters\n", + "\n", + "First, we start by initializing the JASCO model:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m \n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01maudiocraft\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmodels\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m JASCO\n\u001b[1;32m 4\u001b[0m chords_mapping_path \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mabspath(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m../../assets/chord_to_index_mapping.pkl\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 5\u001b[0m model \u001b[38;5;241m=\u001b[39m JASCO\u001b[38;5;241m.\u001b[39mget_pretrained(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mfacebook/jasco-chords-drums-melody-400M\u001b[39m\u001b[38;5;124m'\u001b[39m, chords_mapping_path\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m../assets/chord_to_index_mapping.pkl\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "File \u001b[0;32m~/miniconda3/envs/jasco_dev/lib/python3.9/site-packages/audiocraft/__init__.py:24\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124;03mAudioCraft is a general framework for training audio generative models.\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124;03mAt the moment we provide the training code for:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;124;03m improves the perceived quality and reduces the artifacts coming from adversarial decoders.\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# flake8: noqa\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m data, modules, models\n\u001b[1;32m 26\u001b[0m __version__ \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m1.3.0\u001b[39m\u001b[38;5;124m'\u001b[39m\n", + "File \u001b[0;32m~/miniconda3/envs/jasco_dev/lib/python3.9/site-packages/audiocraft/data/__init__.py:10\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124;03m\"\"\"Audio loading and writing support. 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
+    "* `cfg_coef_all` (float, optional): Coefficient used for classifier free guidance - fully conditional term. \n",
+    "    Defaults to 5.0.\n",
+    "* `cfg_coef_txt` (float, optional): Coefficient used for classifier free guidance - additional text conditional term. \n",
+    "    Defaults to 0.0.\n",
+    "\n",
+    "When left unchanged, JASCO will revert to its default parameters."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_generation_params(\n", + " cfg_coef_all=0.0,\n", + " cfg_coef_txt=5.0\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can go ahead and start generating music given textual prompts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text-conditional Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "# set textual prompt\n", + "text = \"Funky groove with electric piano playing blue chords rhythmically\"\n", + "\n", + "# run the model\n", + "print(\"Generating...\") \n", + "output = model.generate(descriptions=[text], progress=True)\n", + "\n", + "# display the result\n", + "print(f\"Text: {text}\\n\")\n", + "display_audio(output, sample_rate=model.compression_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can start adding temporal controls! We begin with conditioning on chord progressions:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Chords-conditional Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.set_generation_params(\n", + " cfg_coef_all=1.5,\n", + " cfg_coef_txt=3.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "# set textual prompt\n", + "text = \"Strings, woodwind, orchestral, symphony.\"\n", + "\n", + "# define chord progression\n", + "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", + "\n", + "# run the model\n", + "print(\"Generating...\")\n", + "output = model.generate_music(descriptions=[text], chords=chords, progress=True)\n", + "\n", + "# display the result\n", + "print(f'Text: {text}')\n", + "print(f'Chord progression: {chords}')\n", + "display_audio(output, sample_rate=model.compression_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can condition the generation on drum tracks:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Drums-conditional Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torchaudio\n", + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "\n", + "# load drum prompt\n", + "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n", + "\n", + "# set textual prompt \n", + "text = \"distortion guitars, heavy rock, catchy beat\"\n", + "\n", + "# run the model\n", + "print(\"Generating...\")\n", + "output = model.generate_music(\n", + " descriptions=[text],\n", + " drums_wav=drums_waveform,\n", + " drums_sample_rate=sr,\n", + " progress=True\n", + ")\n", + "\n", + "# display the result\n", + "print('drum prompt:')\n", + "display_audio(drums_waveform, sample_rate=sr)\n", + "print(f'Text: {text}')\n", + "display_audio(output, sample_rate=model.compression_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also combine multiple temporal controls! 
Let's move on to generating with both chords and drums conditioning:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Drums + Chords conditioning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torchaudio\n", + "from audiocraft.utils.notebook import display_audio\n", + "\n", + "\n", + "# load drum prompt\n", + "drums_waveform, sr = torchaudio.load(\"../assets/sep_drums_1.mp3\")\n", + "\n", + "# set textual prompt \n", + "text = \"string quartet, orchestral, dramatic\"\n", + "\n", + "# define chord progression\n", + "chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]\n", + "\n", + "# run the model\n", + "print(\"Generating...\")\n", + "output = model.generate_music(\n", + " descriptions=[text],\n", + " drums_wav=drums_waveform,\n", + " drums_sample_rate=sr,\n", + " chords=chords,\n", + " progress=True\n", + ")\n", + "\n", + "# display the result\n", + "print('drum prompt:')\n", + "display_audio(drums_waveform, sample_rate=sr)\n", + "print(f'Chord progression: {chords}')\n", + "print(f'Text: {text}')\n", + "display_audio(output, sample_rate=model.compression_model.sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Melody + Drums + Chords conditioning - inference example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "from demucs import pretrained\n", + "from demucs.apply import apply_model\n", + "from demucs.audio import convert_audio\n", + "import torch\n", + "from audiocraft.utils.notebook import display_audio\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# --------------------------\n", + "# First, choose file to load\n", + "# --------------------------\n", + "fnames = ['salience_1', 'salience_2']\n", + "chords = [\n", + " [('N', 0.0), ('Eb7', 1.088000000), ('C#', 4.352000000), ('D', 4.864000000), ('Dm7', 6.720000000), ('G7', 8.256000000), ('Am7b5/G', 9.152000000)], # for salience 1\n", + " [('N', 0.0), ('C', 0.320000000), ('Dm7', 3.456000000), ('Am', 4.608000000), ('F', 8.320000000), ('C', 9.216000000)] # for salience 2\n", + "]\n", + "file_idx = 0 # either 0 or 1\n", + "\n", + "\n", + "# ------------------------------------\n", + "# display audio, melody map and chords\n", + "# ------------------------------------\n", + "def plot_chromagram(tensor):\n", + " # Check if tensor is a PyTorch tensor\n", + " if not torch.is_tensor(tensor):\n", + " raise ValueError('Input should be a PyTorch tensor')\n", + " tensor = tensor.numpy().T # C, T\n", + " plt.figure(figsize=(20, 20))\n", + " plt.imshow(tensor, cmap='binary', interpolation='nearest', origin='lower')\n", + " plt.show()\n", + "\n", + "# load salience and display the corresponding wav\n", + "melody_prompt_wav, melody_prompt_sr = torchaudio.load(f\"../assets/{fnames[file_idx]}.wav\")\n", + "print(\"Source melody:\")\n", + "display_audio(melody_prompt_wav, sample_rate=melody_prompt_sr)\n", + "melody = torch.load(f\"../assets/{fnames[file_idx]}.th\", weights_only=True)\n", + "plot_chromagram(melody)\n", + "print(\"Chords:\")\n", + "print(chords[file_idx])\n", + "\n", + "# --------------------------------------------------\n", + "# use demucs to seperate the drums stem from src mix\n", + "# --------------------------------------------------\n", + "def _get_drums_stem(wav: torch.Tensor, sample_rate: int) -> torch.Tensor:\n", + " \"\"\"Get parts of the wav that 
holds the drums, extracting the main stems from the wav.\"\"\"\n", + " demucs_model = pretrained.get_model('htdemucs').to('cuda')\n", + " wav = convert_audio(\n", + " wav, sample_rate, demucs_model.samplerate, demucs_model.audio_channels) # type: ignore\n", + " stems = apply_model(demucs_model, wav.cuda().unsqueeze(0), device='cuda').squeeze(0)\n", + " drum_stem = stems[demucs_model.sources.index('drums')] # extract relevant stems for drums conditioning\n", + " return convert_audio(drum_stem.cpu(), demucs_model.samplerate, sample_rate, 1) # type: ignore\n", + "drums_wav = _get_drums_stem(melody_prompt_wav, melody_prompt_sr)\n", + "print(\"Separated drums:\")\n", + "display_audio(drums_wav, sample_rate=melody_prompt_sr)\n", + "\n", + "# ----------------------------------\n", + "# Generate using the loaded controls\n", + "# ----------------------------------\n", + "# these are free-form texts written randomly\n", + "texts = [\n", + " '90s rock with heavy drums and hammond',\n", + " '80s pop with groovy synth bass and drum machine',\n", + " 'folk song with leading accordion',\n", + "]\n", + "\n", + "print(\"Generating...\")\n", + "# replacing dynammic solver with simple euler solver\n", + "model.set_generation_params(cfg_coef_all=1.5, cfg_coef_txt=2.5, euler=True, euler_steps=50) # manually set with euler solver\n", + "output = model.generate_music(\n", + " descriptions=texts,\n", + " chords=chords[file_idx],\n", + " drums_wav=drums_wav,\n", + " drums_sample_rate=melody_prompt_sr,\n", + " melody_salience_matrix=melody.permute(1, 0),\n", + " progress=True\n", + ")\n", + "display_audio(output, sample_rate=model.compression_model.sample_rate)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jasco_dev", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/JASCO.md b/docs/JASCO.md new file mode 100644 index 00000000..cf723939 --- /dev/null +++ b/docs/JASCO.md @@ -0,0 +1,221 @@ +# JASCO: Joint Audio And Symbolic Conditioning for Temporally Controlled Text-To-Music Generation + +AudioCraft provides the code and models for JASCO, [Joint Audio And Symbolic Conditioning for Temporally Controlled Text-To-Music Generation][arxiv]. + +We present JASCO, a temporally controlled text-to-music generation model utilizing both symbolic and audio-based conditions. +JASCO can generate high-quality music samples conditioned on global text descriptions along with fine-grained local controls. +JASCO is based on the Flow Matching modeling paradigm together with a novel conditioning method, allowing for music generation controlled both locally (e.g., chords) and globally (text description). + +Check out our [sample page][sample_page] or test the available demo! + +We use ~16K hours of licensed music to train JASCO. + + +## Model Card + +See [the model card](../model_cards/JASCO_MODEL_CARD.md). + + +## Installation + +First, Please follow the AudioCraft installation instructions from the [README](../README.md). + +Then, download and install chord_extractor from [source](http://www.isophonics.net/nnls-chroma) + +See further required installation under **Data Preprocessing** section + +## Usage + +We currently offer two ways to interact with JASCO: +1. 
You can use the gradio demo locally by running [`python -m demos.jasco_app`](../demos/jasco_app.py); you can add `--share` to deploy a shareable space mounted on your device.
+2. You can play with JASCO by running the jupyter notebook at [`demos/jasco_demo.ipynb`](../demos/jasco_demo.ipynb) locally.
+
+## API
+
+We provide a simple API and pre-trained models:
+- `facebook/jasco-chords-drums-400M`: 400M model, text to music with chords and drums support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-400M)
+- `facebook/jasco-chords-drums-1B`: 1B model, text to music with chords and drums support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-1B)
+
+
+See below for a quick example of using the API.
+
+```python
+from audiocraft.data.audio import audio_write
+from audiocraft.models import JASCO
+
+model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M', chords_mapping_path='../assets/chord_to_index_mapping.pkl')
+
+model.set_generation_params(
+    cfg_coef_all=5.0,
+    cfg_coef_txt=0.0
+)
+
+# set textual prompt
+text = "Strings, woodwind, orchestral, symphony."
+
+# define chord progression
+chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]
+
+# run inference
+output = model.generate_music(descriptions=[text], chords=chords, progress=True)
+
+audio_write('output', output.cpu().squeeze(0), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+For more examples check out `demos/jasco_demo.ipynb`.
+
+## 🤗 Transformers Usage
+
+Coming soon...
+
+## Data Preprocessing
+
+In order to use the JascoDataset with chords / melody conditioning, please follow the instructions below:
+
+
+### Chords conditioning
+
+To extract chords from your data, follow these steps:
+
+1. Prepare a `*.jsonl` file containing the list of absolute file paths in your dataset: simply absolute paths separated by newlines.
+2. Download and install chord_extractor from [source](http://www.isophonics.net/nnls-chroma)
+3. For training purposes run: `python scripts/chords/extract_chords.py --src_jsonl_file= --target_output_dir=`
+
+and then run: `python scripts/chords/build_chord_map.py --chords_folder= --output_directory=` + +4. For evaluation of our released models run: `python scripts/chords/extract_chords.py --src_jsonl_file= --target_output_dir= --path_to_pre_defined_map=` +
+and then run: `python scripts/chords/build_chord_maps.py --chords_folder= --output_directory= --path_to_pre_defined_map=`
+
+
+NOTE: the current scripts assume that all audio files are in `.wav` format; some changes may be required if your data consists of other formats.
+
+NOTE: a predefined chord mapping file is available in the `assets` directory.
+
+### Melody conditioning
+
+This section relies on the [Deepsalience repo](https://github.com/rabitt/ismir2017-deepsalience), with a few custom scripts added.
+
+#### Clone repo and create virtual environment
+1. `git clone git@github.com:lonzi/ismir2017-deepsalience.git forked_deepsalience_repo`
+2. `cd forked_deepsalience_repo`
+3. `conda create --name deep_salience python=3.7`
+4. `conda activate deep_salience`
+5. `pip install -r requirements.txt`
+
+
+#### Salience map dumps (of an entire directory, using a slurm job)
+
+##### From src dir
+
+1. Create the job array: `python predict/create_predict_saliency_cmds.py --src_dir= --out_dir= --n_shards= --multithread`
+2. Run the job array: `sbatch predict_saliency.sh`
+
+##### From track list
+
+1. Create the job array: `python predict/create_predict_saliency_cmds.py --tracks_list=tracks_train.txt --out_dir= --n_shards=2 --multithread --sbatch_script_name=predict_saliency_train.sh --saliency_threshold=`
+2. Run the job array: `sbatch predict_saliency_train.sh`
+
+tracks_train.txt: a list of track paths to process, separated by newlines
+
+
+## Training
+
+The [JascoSolver](../audiocraft/solvers/jasco.py) implements JASCO's training pipeline: a conditional flow matching objective over the continuous latents extracted from a pre-trained EnCodec model (see the [EnCodec documentation](./ENCODEC.md) for more details on how to train such a model).
+
+Note that **we do NOT provide any of the datasets** used for training JASCO.
+We provide a dummy dataset containing just a few examples for illustrative purposes.
+
+Please read the [TRAINING documentation](./TRAINING.md) first, in particular the Environment Setup section.
+
+
+### Fine tuning existing models
+
+You can initialize your model from one of the pretrained models by using the `continue_from` argument, in particular:
+
+```bash
+# Using pretrained JASCO model.
+dora run solver=jasco/chords_drums model/lm/model_scale=small continue_from=//pretrained/facebook/jasco-chords-drums-400M conditioner=jasco_chords_drums
+
+# Using another model you already trained with a Dora signature SIG.
+dora run solver=jasco/chords_drums model/lm/model_scale=small continue_from=//sig/SIG conditioner=jasco_chords_drums
+
+# Or providing manually a path
+dora run solver=jasco/chords_drums model/lm/model_scale=small conditioner=jasco_chords_drums continue_from=/checkpoints/my_other_xp/checkpoint.th
+```
+
+**Warning:** You are responsible for selecting the other parameters in a way that makes them compatible with the model you are fine-tuning. Configuration is NOT automatically inherited from the model you continue from. In particular, make sure to select the proper `conditioner` and `model/lm/model_scale`.
+
+**Warning:** We currently do not support fine-tuning a model with slightly different layers. If you decide to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`. If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict `{'best_state': {'model': model_state_dict_here}}`, as in the sketch below.
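+A minimal sketch of crafting such a checkpoint (the module and output path below are hypothetical placeholders, not part of the training pipeline):
+
+```python
+import torch
+
+# Placeholder module: in practice this is your manually adapted model whose
+# state_dict keys already match the JASCO architecture you intend to load it into.
+adapted_model = torch.nn.Linear(4, 4)
+
+# Wrap the weights in the layout expected by `continue_from` checkpoints.
+checkpoint = {'best_state': {'model': adapted_model.state_dict()}}
+torch.save(checkpoint, '/checkpoints/my_adapted_xp/checkpoint.th')  # hypothetical path
+```
+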
Directly give the path to `continue_from` without a `//pretrained/` prefix. + + +### Evaluation & Generation stage + +See [MusicGen](./MUSICGEN.md) + +### Playing with the model + +Once you have launched some experiments, you can easily get access +to the Solver with the latest trained model using the following snippet. + +```python +from audiocraft.solvers.jasco import JascoSolver + +solver = JascoSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) +solver.model +solver.dataloaders +``` + +### Importing / Exporting models + +We do not support currently loading a model from the Hugging Face implementation or exporting to it. +If you want to export your model in a way that is compatible with `audiocraft.models.JASCO` +API, you can run: + +```python +from audiocraft.utils import export +from audiocraft import train +xp = train.main.get_xp_from_sig('SIG_OF_LM') +export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin') +# You also need to bundle the EnCodec model you used !! +## Case 1) you trained your own +xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC') +export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin') +## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix. +## This will actually not dump the actual model, simply a pointer to the right model to download. +export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin') +``` + +Now you can load your custom model with: +```python +import audiocraft.models +jasco = audiocraft.models.JASCO.get_pretrained('/checkpoints/my_audio_lm/') +``` + + +### Learn more + +Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). + + +## Citation +``` +@misc{tal2024joint, + title={Joint Audio and Symbolic Conditioning for Temporally Controlled Text-to-Music Generation}, + author={Or Tal and Alon Ziv and Itai Gat and Felix Kreuk and Yossi Adi}, + year={2024}, + eprint={2406.10970}, + archivePrefix={arXiv}, + primaryClass={cs.SD} +} +``` + +## License + +See license information in the [model card](../model_cards/JASCO_MODEL_CARD.md). + +[arxiv]: https://arxiv.org/pdf/2406.10970 +[sample_page]: https://pages.cs.huji.ac.il/adiyoss-lab/JASCO/ diff --git a/model_cards/JASCO_MODEL_CARD.md b/model_cards/JASCO_MODEL_CARD.md new file mode 100644 index 00000000..dc6270c0 --- /dev/null +++ b/model_cards/JASCO_MODEL_CARD.md @@ -0,0 +1,152 @@ +## Model details + +**Organization developing the model:** The FAIR team of Meta AI. + +**Model date:** JASCO was trained in November 2024. + +**Model version:** This is the version 1 of the model. + +**Model type:** JASCO consists of an EnCodec model for audio tokenization, and a flow-matching model based on the transformer architecture for music modeling. +The model comes in different sizes: 400M and 1B; and currently have a two variant: text-to-music + {chords, drums} controls and text-to-music + {chords, drums, melody} controls. +JASCO is trained with condition dropout and could be used for inference with dropped conditions. + +**Paper or resources for more information:** More information can be found in the paper [Joint Audio And Symbolic Conditioning for Temporally Controlled Text-To-Music Generation][arxiv]. + +**Citation details:** + +Code was implemented by Or Tal and Alon Ziv. 
+
+```
+@misc{tal2024joint,
+    title={Joint Audio and Symbolic Conditioning for Temporally Controlled Text-to-Music Generation},
+    author={Or Tal and Alon Ziv and Itai Gat and Felix Kreuk and Yossi Adi},
+    year={2024},
+    eprint={2406.10970},
+    archivePrefix={arXiv},
+    primaryClass={cs.SD}
+}
+```
+
+**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
+
+**Where to send questions or comments about the model:** Questions and comments about JASCO can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
+
+## Intended use
+**Primary intended use:** The primary use of JASCO is research on AI-based music generation, including:
+
+- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
+- Generation of music guided by text and (optionally) local controls, to understand current abilities of generative AI models by machine learning amateurs
+
+**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
+
+**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+## Metrics
+
+**Models performance measures:** We used the following objective measures to evaluate the model on a standard music benchmark:
+
+- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish).
+- CLAP Score between the audio embedding and the text embedding extracted from a pre-trained CLAP model.
+- Melody cosine similarity - pairwise comparison of chromagrams extracted from reference and generated waveforms.
+- Onset F1 - pairwise comparison of onsets extracted from reference and generated waveforms.
+- Chords Intersection over Union (IOU) - pairwise comparison of symbolic chords extracted from reference and generated waveforms.
+
+Additionally, we ran qualitative studies with human participants, evaluating the performance of the model along the following axes:
+
+- Overall quality of the music samples;
+- Text relevance to the provided text input;
+- Melody match w.r.t the reference signal;
+- Drum beat match w.r.t the reference signal.
+
+More details on performance measures and human studies can be found in the [paper][arxiv].
+
+**Decision thresholds:** Not applicable.
+
+## Evaluation datasets
+
+The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
+
+## Training datasets
+
+The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), the [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
+
+## Evaluation results
+
+Below are the objective metrics obtained on MusicCaps with the released models.
+
+Text-to-music with temporal controls
+
+| Model | Frechet Audio Distance | Text Consistency | Chord IOU | Onset F1 | Melody Cosine Similarity |
+|---|---|---|---|---|---|
+| facebook/jasco-chords-drums-400M | 5.866 | 0.284 | 0.588 | 0.328 | 0.096 |
+| facebook/jasco-chords-drums-1B | 5.587 | 0.291 | 0.589 | 0.331 | 0.097 |
+| facebook/jasco-chords-drums-melody-400M | 4.730 | 0.317 | 0.689 | 0.379 | 0.423 |
+| facebook/jasco-chords-drums-melody-1B | 5.098 | 0.313 | 0.690 | 0.378 | 0.427 |
+
+Note: the recommended CFG coefficient ratio is 1:2 ('all':'text'); results for chords-drums-melody were sampled with all: 1.75, text: 3.5.
+
+Text-to-music without temporal controls (dropped)
+
+| Model | Frechet Audio Distance | Text Consistency | Chord IOU | Onset F1 | Melody Cosine Similarity |
+|---|---|---|---|---|---|
+| facebook/jasco-chords-drums-400M | 5.648 | 0.272 | 0.070 | 0.204 | 0.093 |
+| facebook/jasco-chords-drums-1B | 5.602 | 0.281 | 0.071 | 0.214 | 0.093 |
+| facebook/jasco-chords-drums-melody-400M | 5.816 | 0.293 | 0.091 | 0.203 | 0.098 |
+| facebook/jasco-chords-drums-melody-1B | 5.470 | 0.297 | 0.097 | 0.208 | 0.097 |
+
+## Limitations and biases
+
+**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on ~16K hours of data; we believe that scaling the model to larger datasets can further improve its performance.
+
+**Mitigations:**
+Pre-trained models were used to obtain pseudo symbolic supervision. Refer to the **Data Preprocessing** section in [JASCO's docs](../docs/JASCO.md).
+
+**Limitations:**
+
+- The model is not able to generate realistic vocals.
+- The model has been trained with English descriptions and will not perform as well in other languages.
+- The model does not perform equally well for all music styles and cultures.
+- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering and experimentation with classifier-free guidance coefficients may be required to obtain satisfying results.
+- The model can be sensitive to CFG coefficients, as the melody condition introduces a strong bias that may require a higher text coefficient during generation; some hyper-parameter search may be necessary to obtain the desired results (see the sketch at the end of this section).
+
+**Biases:** The source of data is potentially lacking diversity and not all music cultures are equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exist. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
+
+**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will help broaden the application to new and more representative data.
+
+**Use cases:** Users must be aware of the biases, limitations and risks of the model. JASCO is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
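+As an illustration of the hyper-parameter search mentioned above, here is a minimal, hypothetical sweep over the CFG coefficients using the API described in the next section (the prompt, chord progression and coefficient values are arbitrary examples, not recommended settings):
+
+```python
+from audiocraft.models import JASCO
+
+model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M',
+                             chords_mapping_path='../assets/chord_to_index_mapping.pkl')
+
+text = "80s pop with groovy synth bass and drum machine"
+chords = [('C', 0.0), ('F', 4.0), ('G', 6.0)]
+
+# Keep roughly the 1:2 'all':'text' ratio noted above and sweep the overall scale;
+# inspect each candidate and keep the best-sounding one.
+for cfg_all in (1.25, 1.5, 1.75, 2.0):
+    model.set_generation_params(cfg_coef_all=cfg_all, cfg_coef_txt=2 * cfg_all)
+    output = model.generate_music(descriptions=[text], chords=chords, progress=False)
+```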
+
+## API
+
+We provide a simple API and pre-trained models:
+- `facebook/jasco-chords-drums-400M`: 400M model, text to music with chords and drums support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-400M)
+- `facebook/jasco-chords-drums-1B`: 1B model, text to music with chords and drums support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-1B)
+- `facebook/jasco-chords-drums-melody-400M`: 400M model, text to music with chords, drums and melody support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-melody-400M)
+- `facebook/jasco-chords-drums-melody-1B`: 1B model, text to music with chords, drums and melody support, generates 10-second samples - [🤗 Hub](https://huggingface.co/facebook/jasco-chords-drums-melody-1B)
+
+
+Below is a quick example of using the API.
+
+```python
+from audiocraft.models import JASCO
+from audiocraft.data.audio import audio_write
+
+model = JASCO.get_pretrained('facebook/jasco-chords-drums-400M', chords_mapping_path='../assets/chord_to_index_mapping.pkl')
+
+model.set_generation_params(
+    cfg_coef_all=1.5,
+    cfg_coef_txt=0.5
+)
+
+# set textual prompt
+text = "Strings, woodwind, orchestral, symphony."
+
+# define chord progression
+chords = [('C', 0.0), ('D', 2.0), ('F', 4.0), ('Ab', 6.0), ('Bb', 7.0), ('C', 8.0)]
+
+# run inference
+output = model.generate_music(descriptions=[text], chords=chords, progress=True)
+
+audio_write('output', output.cpu().squeeze(0), model.sample_rate, strategy="loudness", loudness_compressor=True)
+```
+
+[arxiv]: https://arxiv.org/pdf/2406.10970
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 74b9ef8b..99c2bd64 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,3 +26,4 @@ torchvision==0.16.0
 torchtext==0.16.0
 pesq
 pystoi
+torchdiffeq
diff --git a/scripts/chords/build_chord_maps.py b/scripts/chords/build_chord_maps.py
new file mode 100644
index 00000000..410875ac
--- /dev/null
+++ b/scripts/chords/build_chord_maps.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
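+# This script reads the `*.chords` files produced by scripts/chords/extract_chords.py,
+# builds a combined {track_id: chord_sequence} dictionary and a chord-to-index mapping,
+# and pickles both to the output directory.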
+import os +import pickle +from tqdm import tqdm +import argparse + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--chords_folder', type=str, required=True, + help='path to directory containing parsed chords files') + parser.add_argument('--output_directory', type=str, required=False, + help='path to output directory to generate code maps to, \ + if not given - chords_folder would be used', default='') + parser.add_argument('--path_to_pre_defined_map', type=str, required=False, + help='for evaluation purpose, use pre-defined chord-to-index map', default='') + args = parser.parse_args() + return args + + +def get_chord_dict(chord_folder: str): + chord_dict = {} + distinct_chords = set() + + chord_to_index = {} # Mapping between chord and index + index_counter = 0 + + for filename in tqdm(os.listdir(chord_folder)): + if filename.endswith(".chords"): + idx = filename.split(".")[0] + + with open(os.path.join(chord_folder, filename), "rb") as file: + chord_data = pickle.load(file) + + for chord, _ in chord_data: + distinct_chords.add(chord) + if chord not in chord_to_index: + chord_to_index[chord] = index_counter + index_counter += 1 + + chord_dict[idx] = chord_data + chord_to_index["UNK"] = index_counter + return chord_dict, distinct_chords, chord_to_index + + +def get_predefined_chord_to_index_map(path_to_chords_to_index_map: str): + def inner(chord_folder: str): + chords_to_index = pickle.load(open(path_to_chords_to_index_map, "rb")) + distinct_chords = set(chords_to_index.keys()) + chord_dict = {} + for filename in tqdm(os.listdir(chord_folder), desc=f'iterating: {chord_folder}'): + if filename.endswith(".chords"): + idx = filename.split(".")[0] + + with open(os.path.join(chord_folder, filename), "rb") as file: + chord_data = pickle.load(file) + + chord_dict[idx] = chord_data + return chord_dict, distinct_chords, chords_to_index + return inner + + +if __name__ == "__main__": + '''This script processes and maps chord data from a directory of parsed chords files, + generating two output files: a combined chord dictionary and a chord-to-index mapping.''' + args = parse_args() + chord_folder = args.chords_folder + output_dir = args.output_directory + if output_dir == '': + output_dir = chord_folder + func = get_chord_dict + if args.path_to_pre_defined_map != "": + func = get_predefined_chord_to_index_map(args.path_to_pre_defined_map) + + chord_dict, distinct_chords, chord_to_index = func(chord_folder) + + # Save the combined chord dictionary as a pickle file + combined_filename = os.path.join(output_dir, "combined_chord_dict.pkl") + with open(combined_filename, "wb") as file: + pickle.dump(chord_dict, file) + + # Save the chord-to-index mapping as a pickle file + mapping_filename = os.path.join(output_dir, "chord_to_index_mapping.pkl") + with open(mapping_filename, "wb") as file: + pickle.dump(chord_to_index, file) + + print("Number of distinct chords:", len(distinct_chords)) + print("Chord dictionary:", chord_to_index) diff --git a/scripts/chords/extract_chords.py b/scripts/chords/extract_chords.py new file mode 100644 index 00000000..f6bf727c --- /dev/null +++ b/scripts/chords/extract_chords.py @@ -0,0 +1,73 @@ +# Env - chords_extraction on devfair + +import pickle +import argparse +from chord_extractor.extractors import Chordino # type: ignore +from chord_extractor import clear_conversion_cache, LabelledChordSequence # type: ignore +import os +from tqdm import tqdm + + +def parse_args(): + parser = argparse.ArgumentParser() + 
parser.add_argument('--src_jsonl_file', type=str, required=True, + help='abs path to .jsonl file containing list of absolute file paths seperated by new line') + parser.add_argument('--target_output_dir', type=str, required=True, + help='target directory to save parsed chord files to, individual files will be saved inside') + parser.add_argument("--override", action="store_true") + args = parser.parse_args() + return args + + +def save_to_db_cb(tgt_dir: str): + # Every time one of the files has had chords extracted, receive the chords here + # along with the name of the original file and then run some logic here, e.g. to + # save the latest data to DB + def inner(results: LabelledChordSequence): + path = results.id.split(".wav") + + sequence = [(item.chord, item.timestamp) for item in results.sequence] + + if len(path) != 2: + print("Something") + print(path) + else: + file_idx = path[0].split("/")[-1] + with open(f"{tgt_dir}/{file_idx}.chords", "wb") as f: + # dump the object to the file + pickle.dump(sequence, f) + return inner + + +if __name__ == "__main__": + '''This script extracts chord data from a list of audio files using the Chordino extractor, + and saves the extracted chords to individual files in a target directory.''' + print("parsed args") + args = parse_args() + files_to_extract_from = list() + with open(args.src_jsonl_file, "r") as json_file: + for line in tqdm(json_file.readlines()): + # fpath = json.loads(line.replace("\n", ""))['path'] + fpath = line.replace("\n", "") + if not args.override: + fname = fpath.split("/")[-1].replace(".wav", ".chords") + if os.path.exists(f"{args.target_output_dir}/{fname}"): + continue + files_to_extract_from.append(line.replace("\n", "")) + + print(f"num files to parse: {len(files_to_extract_from)}") + + chordino = Chordino() + + # Optionally clear cache of file conversions (e.g. wav files that have been converted from midi) + clear_conversion_cache() + + # Run bulk extraction + res = chordino.extract_many( + files_to_extract_from, + callback=save_to_db_cb(args.target_output_dir), + num_extractors=80, + num_preprocessors=80, + max_files_in_cache=400, + stop_on_error=False, + ) diff --git a/scripts/chords/job_array_example.sh b/scripts/chords/job_array_example.sh new file mode 100644 index 00000000..5a1ce69a --- /dev/null +++ b/scripts/chords/job_array_example.sh @@ -0,0 +1,17 @@ +#!/bin/zsh +#SBATCH --job-name=my_job_array +#SBATCH --array=0-N # adjust the range of indices as needed +#SBATCH --output=logs/%A_%a.out # output file name format, this assumes there exists a /logs directory +#SBATCH --error=logs/%A_%a.err # error file name format, this assumes there exists a /logs directory +#SBATCH --time=01:00:00 # adjust the time limit as needed +#SBATCH --nodes=1 # adjust the number of nodes as needed +#SBATCH --ntasks-per-node=1 # adjust the number of tasks per node as needed +#SBATCH --cpus-per-task=8 # adjust the number of CPUs per task as needed +#SBATCH --mem-per-cpu=16G # adjust the memory per CPU as needed + +# Load any necessary modules or dependencies +conda activate your_env + +# run extraction of chords in job array +python scripts/chords/extract_chords.py --src_jsonl_file /path/to/parsed/filepaths_${SLURM_ARRAY_TASK_ID}.jsonl --target_output_dir /target/directory/to/save/chords/to --path_to_pre_defined_map /path/to/predefined/chord_to_index_mapping.pkl +