"""MaMMUT configuration."""

from typing import Any, Dict, Optional, Union

from transformers import (
    AutoConfig,
    CLIPConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
    PretrainedConfig,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)


class MultimodalConfig(PretrainedConfig):
    """Configuration for the MaMMUT multimodal text decoder tower."""

    model_type = "mammut_text_model"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        context_length: int = 77,
        vocab_size: int = 49408,
        hidden_size: int = 1024,
        layers: int = 12,
        batch_first: bool = True,
        **kwargs: Any,
    ):
        # Forward the remaining kwargs to PretrainedConfig, which stores
        # unknown keys as attributes itself (no manual setattr loop needed).
        super().__init__(**kwargs)
        self.mlp_ratio = mlp_ratio
        self.dim_head = dim_head
        self.heads = heads
        self.n_queries = n_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.cross_attn_ratio = cross_attn_ratio
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens
        self.has_mlp = has_mlp
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # `width` mirrors `hidden_size` (open_clip-style naming).
        self.width = hidden_size
        self.layers = layers
        self.batch_first = batch_first
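
# Minimal usage sketch (illustrative values, not canonical MaMMUT
# hyperparameters): PretrainedConfig keeps unknown keyword arguments as
# attributes, and `width` mirrors `hidden_size`.
#
#     cfg = MultimodalConfig(hidden_size=768, layers=6, custom_flag=True)
#     assert cfg.width == cfg.hidden_size == 768
#     assert cfg.custom_flag is True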


class MammutTextConfig(MultimodalConfig, CLIPTextConfig):
    """Text-tower configuration: MaMMUT decoder options on top of CLIP text settings."""

    model_type = "mammut_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        max_position_embeddings: int = 77,
        vocab_size: int = 49408,
        num_hidden_layers: int = 12,
        hidden_size: int = 1024,
        attention_dropout: float = 0.0,
        hidden_act: str = "gelu",
        layer_norm_eps: float = 1e-5,
        intermediate_size: Optional[int] = None,
        initializer_factor: float = 0.02,
        logit_scale_init_value: float = 2.6592,
        **kwargs: Any,
    ):
        # Derive the MLP width from mlp_ratio when not given explicitly.
        if intermediate_size is None:
            intermediate_size = hidden_size * mlp_ratio
        # Keep the open_clip-style aliases in sync with the HF names unless
        # they were passed explicitly (e.g. when rebuilding from a dict).
        kwargs.setdefault("heads", num_attention_heads)
        kwargs.setdefault("layers", num_hidden_layers)
        kwargs.setdefault("context_length", max_position_embeddings)
        super().__init__(
            mlp_ratio=mlp_ratio,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            attention_dropout=attention_dropout,
            logit_scale_init_value=logit_scale_init_value,
            max_position_embeddings=max_position_embeddings,
            layer_norm_eps=layer_norm_eps,
            intermediate_size=intermediate_size,
            initializer_factor=initializer_factor,
            hidden_act=hidden_act,
            **kwargs,
        )
        # hidden_size, num_attention_heads, logit_scale_init_value, etc. are
        # already set by the parent configs above; only the architecture hint
        # still needs to be recorded here.
        self.architectures = ["MammutTextModel"]
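
# Sketch of the intermediate_size fallback above (illustrative numbers):
#
#     cfg = MammutTextConfig(hidden_size=1024, mlp_ratio=4)
#     assert cfg.intermediate_size == 4096  # hidden_size * mlp_ratio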


class MammutVisionConfig(CLIPVisionConfig):
    """Vision-tower configuration: CLIP vision settings plus MaMMUT pooler options."""

    model_type = "mammut_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        image_size: int = 224,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 12,
        **kwargs: Any,
    ):
        # mlp_ratio, dim_head, n_queries, width, layers, ... are not
        # CLIPVisionConfig parameters; PretrainedConfig stores them as
        # extra attributes.
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            image_size=image_size,
            patch_size=patch_size,
            width=width,
            layers=layers,
            **kwargs,
        )
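
# Note: `width` and `layers` are open_clip-style names stored alongside
# CLIPVisionConfig's own `hidden_size` / `num_hidden_layers`, which keep
# their CLIP defaults unless passed explicitly:
#
#     cfg = MammutVisionConfig(width=1024, hidden_size=1024)
#     assert cfg.width == cfg.hidden_size == 1024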


class MammutConfig(CLIPConfig):
    """Top-level MaMMUT configuration bundling the text and vision sub-configs."""

    model_type = "mammut"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        text_config: Optional[Union[MammutTextConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[MammutVisionConfig, Dict[str, Any]]] = None,
        projection_dim: int = 768,
        logit_scale_init_value: float = 2.6592,
        **kwargs: Any,
    ):
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            **kwargs,
        )
        # CLIPConfig builds plain CLIP sub-configs in super().__init__();
        # replace them with the MaMMUT variants, accepting either ready-made
        # config objects or plain dicts (e.g. when loaded from JSON).
        if text_config is None:
            text_config = MammutTextConfig()
        elif isinstance(text_config, dict):
            text_config = MammutTextConfig(**text_config)
        self.text_config = text_config
        if vision_config is None:
            vision_config = MammutVisionConfig()
        elif isinstance(vision_config, dict):
            vision_config = MammutVisionConfig(**vision_config)
        self.vision_config = vision_config
        self.text_config.architectures = ["MammutTextModel"]
        self.vision_config.architectures = ["MammutVisionModel"]
        self.projection_dim = projection_dim
        self.hidden_size = self.text_config.hidden_size
        self.logit_scale_init_value = logit_scale_init_value
        self.architectures = ["MammutModel"]

    def _post_init(self):
        # Keep the text sub-config's logit scale in sync with the top level.
        if self.logit_scale_init_value is not None:
            self.text_config.logit_scale_init_value = self.logit_scale_init_value
        # Not every transformers version defines `_post_init` on
        # PretrainedConfig, so only chain up when a parent provides it.
        parent_post_init = getattr(super(), "_post_init", None)
        if parent_post_init is not None:
            parent_post_init()


AutoConfig.register("mammut", MammutConfig)
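

if __name__ == "__main__":
    # Smoke-test sketch (not tied to any released checkpoint): build a
    # default config and round-trip it through its dict serialization.
    config = MammutConfig()
    restored = MammutConfig.from_dict(config.to_dict())
    assert restored.text_config.vocab_size == config.text_config.vocab_size
    print(restored.model_type, restored.projection_dim)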