"""CM3P model configuration""" from transformers import AutoConfig from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging logger = logging.get_logger(__name__) class CM3PMetadataConfig(PretrainedConfig): model_type = "CM3PMetadata" base_config_key = "metadata_config" def __init__( self, cls_embed=True, projection_dim=512, initializer_factor=1.0, vocab_size=1000, hidden_size=256, intermediate_size=512, num_hidden_layers=6, num_attention_heads=4, hidden_activation="gelu", max_position_embeddings=128, initializer_range=0.02, initializer_cutoff_factor=2.0, norm_eps=1e-5, norm_bias=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, global_rope_theta=10000.0, attention_bias=False, attention_dropout=0.0, global_attn_every_n_layers=1, local_attention=128, local_rope_theta=10000.0, embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, decoder_bias=True, deterministic_flash_attn=False, reference_compile=None, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs, ) self.cls_embed = cls_embed self.projection_dim = projection_dim self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.initializer_range = initializer_range self.initializer_cutoff_factor = initializer_cutoff_factor self.norm_eps = norm_eps self.norm_bias = norm_bias self.global_rope_theta = global_rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention self.local_rope_theta = local_rope_theta self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout self.decoder_bias = decoder_bias self.deterministic_flash_attn = deterministic_flash_attn self.reference_compile = reference_compile def to_dict(self): output = super().to_dict() output.pop("reference_compile", None) return output class CM3PAudioConfig(PretrainedConfig): model_type = "CM3PAudio" base_config_key = "audio_config" def __init__( self, hidden_size=512, intermediate_size=1024, num_hidden_layers=6, num_attention_heads=8, hidden_activation="gelu", max_position_embeddings=4096, initializer_range=0.02, initializer_cutoff_factor=2.0, norm_eps=1e-5, norm_bias=False, global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, global_attn_every_n_layers=3, local_attention=128, local_rope_theta=10000.0, embedding_dropout=0.0, mlp_bias=False, mlp_dropout=0.0, decoder_bias=True, deterministic_flash_attn=False, reference_compile=None, projector_intermediate_size=2048, # 4 * hidden_size for a 4x reduction in tokens projector_dim=768, projector_hidden_act="gelu", sample_rate: int = 16000, n_ftt: int = 2048, n_mels: int = 80, hop_length: int = 128, f_min: int = 0, f_max: int = 8000, pad_mode: str = "constant", **kwargs, ): super().__init__(**kwargs) self.vocab_size = 1 self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.initializer_range = initializer_range 


class CM3PAudioConfig(PretrainedConfig):
    model_type = "CM3PAudio"
    base_config_key = "audio_config"

    def __init__(
        self,
        hidden_size=512,
        intermediate_size=1024,
        num_hidden_layers=6,
        num_attention_heads=8,
        hidden_activation="gelu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        global_rope_theta=160000.0,
        attention_bias=False,
        attention_dropout=0.0,
        global_attn_every_n_layers=3,
        local_attention=128,
        local_rope_theta=10000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        deterministic_flash_attn=False,
        reference_compile=None,
        projector_intermediate_size=2048,  # 4 * hidden_size for a 4x reduction in tokens
        projector_dim=768,
        projector_hidden_act="gelu",
        sample_rate: int = 16000,
        n_ftt: int = 2048,
        n_mels: int = 80,
        hop_length: int = 128,
        f_min: int = 0,
        f_max: int = 8000,
        pad_mode: str = "constant",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # The audio tower consumes spectrogram frames rather than discrete
        # tokens, so the vocabulary is a fixed placeholder of size 1.
        self.vocab_size = 1
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.deterministic_flash_attn = deterministic_flash_attn
        self.reference_compile = reference_compile
        self.projector_intermediate_size = projector_intermediate_size
        self.projector_dim = projector_dim
        self.projector_hidden_act = projector_hidden_act
        self.sample_rate = sample_rate
        self.n_ftt = n_ftt
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.f_min = f_min
        self.f_max = f_max
        self.pad_mode = pad_mode

    def to_dict(self):
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output
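
# Back-of-the-envelope arithmetic for the audio defaults above (a sketch; the
# exact frame math depends on the feature extractor, which lives elsewhere):
# sample_rate=16000 with hop_length=128 gives 16000 / 128 = 125 mel frames per
# second, and a 4x token reduction in the projector would leave roughly
# 125 / 4 ≈ 31 audio tokens per second of audio.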


class CM3PBeatmapConfig(PretrainedConfig):
    model_type = "CM3PBeatmap"
    is_composition = True
    base_config_key = "beatmap_config"
    sub_configs = {"audio_config": CM3PAudioConfig}

    def __init__(
        self,
        audio_config: dict = None,
        audio_sos_token_id=3164,
        audio_eos_token_id=3165,
        audio_token_id=3166,
        cls_embed=True,
        projection_dim=512,
        initializer_factor=1.0,
        vocab_size=3167,
        hidden_size=768,
        intermediate_size=1152,
        num_hidden_layers=22,
        num_attention_heads=12,
        hidden_activation="gelu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        initializer_cutoff_factor=2.0,
        norm_eps=1e-5,
        norm_bias=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        global_rope_theta=160000.0,
        attention_bias=False,
        attention_dropout=0.0,
        global_attn_every_n_layers=3,
        local_attention=128,
        local_rope_theta=10000.0,
        embedding_dropout=0.0,
        mlp_bias=False,
        mlp_dropout=0.0,
        decoder_bias=True,
        classifier_bias=False,
        classifier_activation="gelu",
        deterministic_flash_attn=False,
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=None,
        repad_logits_with_grad=False,
        attn_implementation: str = None,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            attn_implementation=attn_implementation,
            **kwargs,
        )
        if audio_config is None:
            audio_config = {}
            logger.info("`audio_config` is `None`. Initializing the `CM3PAudioConfig` with default values.")
        self.audio_config = CM3PAudioConfig(attn_implementation=attn_implementation, **audio_config)
        self.audio_sos_token_id = audio_sos_token_id
        self.audio_eos_token_id = audio_eos_token_id
        self.audio_token_id = audio_token_id
        self.cls_embed = cls_embed
        self.projection_dim = projection_dim
        self.initializer_factor = initializer_factor
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.initializer_cutoff_factor = initializer_cutoff_factor
        self.norm_eps = norm_eps
        self.norm_bias = norm_bias
        self.global_rope_theta = global_rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.local_attention = local_attention
        self.local_rope_theta = local_rope_theta
        self.embedding_dropout = embedding_dropout
        self.mlp_bias = mlp_bias
        self.mlp_dropout = mlp_dropout
        self.decoder_bias = decoder_bias
        self.classifier_bias = classifier_bias
        self.classifier_activation = classifier_activation
        self.deterministic_flash_attn = deterministic_flash_attn
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
        self.repad_logits_with_grad = repad_logits_with_grad

    def to_dict(self):
        output = super().to_dict()
        output.pop("reference_compile", None)
        return output


class CM3PConfig(PretrainedConfig):
    model_type = "CM3P"
    is_composition = True
    sub_configs = {"metadata_config": CM3PMetadataConfig, "beatmap_config": CM3PBeatmapConfig}

    def __init__(
        self,
        metadata_config=None,
        beatmap_config=None,
        projection_dim=512,
        logit_scale_init_value=2.6592,
        initializer_factor=1.0,
        initializer_range=0.02,
        loss_type=None,
        has_decoder_head=False,
        attn_implementation: str = None,
        **kwargs,
    ):
        super().__init__(attn_implementation=attn_implementation, **kwargs)
        if metadata_config is None:
            metadata_config = {}
            logger.debug("`metadata_config` is `None`. Initializing the `CM3PMetadataConfig` with default values.")
        if beatmap_config is None:
            beatmap_config = {}
            logger.debug("`beatmap_config` is `None`. Initializing the `CM3PBeatmapConfig` with default values.")
        self.metadata_config = CM3PMetadataConfig(attn_implementation=attn_implementation, **metadata_config)
        self.beatmap_config = CM3PBeatmapConfig(attn_implementation=attn_implementation, **beatmap_config)
        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value
        self.initializer_factor = initializer_factor
        self.initializer_range = initializer_range
        self.loss_type = loss_type
        self.has_decoder_head = has_decoder_head


AutoConfig.register("CM3PMetadata", CM3PMetadataConfig)
AutoConfig.register("CM3PAudio", CM3PAudioConfig)
AutoConfig.register("CM3PBeatmap", CM3PBeatmapConfig)
AutoConfig.register("CM3P", CM3PConfig)

__all__ = ["CM3PConfig", "CM3PMetadataConfig", "CM3PAudioConfig", "CM3PBeatmapConfig"]
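

# A minimal usage sketch (guarded so importing this module has no side
# effects). It composes a full CM3P config from nested sub-config dicts,
# saves it, and reloads it through `AutoConfig`, which resolves the
# registered model_type "CM3P". The directory name is illustrative.
if __name__ == "__main__":
    config = CM3PConfig(
        metadata_config={"hidden_size": 256},
        beatmap_config={"audio_config": {"hidden_size": 512}},
    )
    config.save_pretrained("./cm3p-config")
    reloaded = AutoConfig.from_pretrained("./cm3p-config")
    assert reloaded.beatmap_config.audio_config.hidden_size == 512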