| from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| |
class MeshExpert(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear.

    The block projects the last dimension from ``hidden_size`` up to
    ``expert_intermediate_size`` and back again. Presumably it is one expert
    in a mixture-of-experts layer (inferred from the name only; confirm
    against the caller).

    Args:
        config: Any object exposing integer attributes ``hidden_size`` and
            ``expert_intermediate_size`` (normally a ``MeshConfig``).
    """

    # The annotation is a string forward reference on purpose. ``MeshConfig``
    # is not imported in this module's visible header, and a bare name would
    # be evaluated when the ``def`` executes, raising NameError.
    def __init__(self, config: "MeshConfig"):
        super().__init__()
        # Attribute names are kept as-is so existing checkpoints'
        # state_dict keys (fc1.*, fc2.*) still load.
        self.fc1 = nn.Linear(config.hidden_size, config.expert_intermediate_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(config.expert_intermediate_size, config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> GELU -> fc2 along the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension equals ``config.hidden_size``.
                Leading dimensions are passed through unchanged by
                ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``config.hidden_size``.
        """
        return self.fc2(self.gelu(self.fc1(x)))
|
|