robinfaro committed on
Commit 6333ecf · verified · 1 Parent(s): c0246f4

Upload custom config and model files

Files changed (4)
  1. __init__.py +1 -1
  2. aux_losses.py +88 -0
  3. configuration.py +42 -0
  4. moe.py +133 -0
__init__.py CHANGED
@@ -1,2 +1,2 @@
  from .configuration_moegpt import MoEGPTConfig
- from .modeling_moegpt import MoEGPTForCausalLM
+ from .modeling import MoEGPTForCausalLM
aux_losses.py ADDED
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def log_mean(x, dim):
+    return torch.logsumexp(x, dim=dim) - torch.log(
+        torch.tensor(x.shape[dim], dtype=torch.float32)
+    )
+
+
+def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
+    """Entropy regularization for the router."""
+
+    entropy_l = lambda l: -(l * l.exp()).sum(-1)
+    # softmax over experts
+    # logits: [batch_size * sequence_length, num_experts]
+    logprobs = F.log_softmax(logits, dim=-1)
+    if mean_over_batch:
+        # take mean probability over batch
+        logprobs = log_mean(logprobs, 0)
+
+    return -entropy_l(logprobs).mean()
+
+
+# two losses below are adapted from
+# https://github.com/google/flaxformer/blob/b725bd2a51d70e866d819c92de166fbf24425e6a/flaxformer/architectures/moe/routing.py
+def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
+    """Computes auxiliary load balancing loss as in Switch Transformer.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
+    implements the loss function presented in equations (4) - (6). It aims to
+    penalize those cases where the routing between experts is unbalanced.
+
+    Args:
+        logits: logits assigned to each expert per token. Shape:
+            <float32>[batch_size * sequence_length, num_experts].
+        expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
+            indices identifying the top num_selected_experts for a given token.
+
+    Returns:
+        The auxiliary loss.
+    """
+    # num_token = batch_size * sequence_length
+    num_token, num_experts = logits.shape
+
+    # Shape: [batch_size * sequence_length, num_selected_experts, num_experts].
+    expert_mask = F.one_hot(expert_indices, num_experts)
+    # For a given token, determine if it was routed to a given expert.
+    # Shape: [batch_size * sequence_length, num_experts]
+    expert_mask, _ = torch.max(expert_mask, dim=-2)
+
+    # shape [num_experts]
+    tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)
+
+    # compute router probability per expert in log space for numerical stability
+    logprobs = F.log_softmax(logits, dim=-1)
+    # take mean probability over batch
+    # shape [num_experts]
+    logprobs = log_mean(logprobs, dim=0)
+    router_prob_per_expert = torch.exp(logprobs)
+    return (
+        torch.mean(  # mean over experts
+            tokens_per_expert * router_prob_per_expert,
+            dtype=torch.float32,
+        )
+        * num_experts
+    )
+
+
+def router_z_loss(router_logits: torch.Tensor) -> float:
+    """Compute router z-loss.
+
+    The router z-loss was introduced in Designing Effective Sparse Expert Models
+    (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
+    small in an effort to improve stability.
+
+    Args:
+        router_logits: <float>[batch_size * sequence_length, num_experts]
+            router logits
+
+    Returns:
+        Scalar router z-loss.
+    """
+    num_tokens, _ = router_logits.shape
+    log_z = torch.logsumexp(router_logits, dim=-1)
+    z_loss = log_z**2
+    return torch.sum(z_loss, dtype=torch.float32) / num_tokens
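
A minimal usage sketch, not part of this commit, of how these losses might be combined into a single auxiliary term for one MoE layer. The toy shapes and the numeric weights are placeholders; in this repo the weights are exposed by the config below as moe_aux_loss_factor and moe_z_loss_factor.

# Hypothetical sketch: combine the router losses given router_logits of shape
# [batch_size * sequence_length, num_experts] and top-k expert indices per token.
import torch

from aux_losses import entropy_reg, load_balancing_loss, router_z_loss

num_tokens, num_experts, top_k = 8, 4, 2
router_logits = torch.randn(num_tokens, num_experts)
_, selected_experts = torch.topk(router_logits, top_k)

aux_loss = (
    0.01 * load_balancing_loss(router_logits, selected_experts)  # moe_aux_loss_factor
    + 1.0 * router_z_loss(router_logits)                         # moe_z_loss_factor
)
# entropy_reg is the other regularizer in this file: the negative entropy of the
# batch-averaged router distribution (minimizing it pushes the router toward uniform)
ent = entropy_reg(router_logits)
print(aux_loss.item(), ent.item())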
configuration.py ADDED
@@ -0,0 +1,42 @@
+from transformers import PretrainedConfig
+
+class MoEGPTConfig(PretrainedConfig):
+    model_type = "moegpt"
+
+    def __init__(
+        self,
+        vocab_size=50304,
+        n_embd=768,
+        n_layer=12,
+        n_head=12,
+        sequence_length=1024,
+        moe=False,
+        moe_routing="standard_gating",
+        moe_num_experts=4,
+        moe_num_experts_per_tok=2,
+        moe_softmax_order="softmax_topk",
+        moe_router_loss="load_balancing_z_loss",
+        moe_aux_loss_factor=0.01,
+        moe_z_loss_factor=1.0,
+        mlp_dim_exp_factor=1.0,
+        dropout=0.0,
+        bias=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.n_embd = n_embd
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.sequence_length = sequence_length
+        self.moe = moe
+        self.moe_routing = moe_routing
+        self.moe_num_experts = moe_num_experts
+        self.moe_num_experts_per_tok = moe_num_experts_per_tok
+        self.moe_softmax_order = moe_softmax_order
+        self.moe_router_loss = moe_router_loss
+        self.moe_aux_loss_factor = moe_aux_loss_factor
+        self.moe_z_loss_factor = moe_z_loss_factor
+        self.mlp_dim_exp_factor = mlp_dim_exp_factor
+        self.dropout = dropout
+        self.bias = bias
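
For reference, a small sketch, not part of this commit, of instantiating the config and round-tripping it through the standard transformers save/load API; the directory name is arbitrary and the import path assumes the uploaded configuration.py is used directly.

# Hypothetical sketch: build an MoE-enabled config and save/reload it with the
# usual PretrainedConfig machinery.
from configuration import MoEGPTConfig

config = MoEGPTConfig(moe=True, moe_num_experts=4, moe_num_experts_per_tok=2)
config.save_pretrained("moegpt-config")  # writes moegpt-config/config.json
reloaded = MoEGPTConfig.from_pretrained("moegpt-config")
print(reloaded.model_type, reloaded.moe_routing, reloaded.moe_num_experts)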
moe.py ADDED
@@ -0,0 +1,133 @@
+"""
+Simple MoE routing implementations that replace the MLP block in a standard transformer.
+References:
+1) Mistral Source for Mixtral MoEs:
+https://github.com/mistralai/mistral-src
+2) ST-MoE:
+https://arxiv.org/abs/2202.08906
+3) Our notepad of MoE resources:
+https://docs.google.com/document/d/1NuQ5jr7V-Jv1ui7p4KrxO_JTz-7bpYcYMmh49EeJ-QA/edit?usp=sharing
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import bisect
+
+
+class MoE(nn.Module):
+    """
+    Simplest MoE implementation with a linear router and softmax over experts.
+
+    Note that in this implementation, we simply loop over the experts and
+    aggregate the results. This is not the most efficient way to do it, but
+    it also avoids the large memory overhead _and_ has no token dropping
+    (because we do not need the capacity factor).
+    """
+
+    def __init__(self, config, mlp):
+        super().__init__()
+        assert config.moe_num_experts > 0
+        self.experts = nn.ModuleList(
+            [mlp(config=config) for _ in range(config.moe_num_experts)]
+        )
+        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
+        self.top_k = config.moe_num_experts_per_tok
+        self.softmax_order = config.moe_softmax_order
+
+    def forward(self, inputs: torch.Tensor):
+        # [batch_size * sequence_length, n_embd]
+        inputs_squashed = inputs.view(-1, inputs.shape[-1])
+        # [batch_size * sequence_length, num_experts]
+        router_logits = self.router(inputs_squashed)
+
+        # note that selected experts will be the same for all orders:
+        # softmax doesn't change the top-k, but the weights are different
+        if self.softmax_order == "softmax_topk":
+            all_probs = F.softmax(router_logits, dim=1)
+            weights, selected_experts = torch.topk(all_probs, self.top_k)
+        elif self.softmax_order == "topk_softmax":
+            weights, selected_experts = torch.topk(router_logits, self.top_k)
+            weights = F.softmax(weights, dim=-1)
+        else:
+            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")
+
+        results = torch.zeros_like(inputs_squashed)
+        # naive looping over experts
+        for i, expert in enumerate(self.experts):
+            batch_idx, nth_expert = torch.where(selected_experts == i)
+            output, _ = expert(inputs_squashed[batch_idx])
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * output
+
+        # return results and router logits (for aux loss calculation later)
+        return results.view_as(inputs), {
+            "router_logits": router_logits,
+            "selected_experts": selected_experts,
+        }
+
+
+class DummyExpert(nn.Module):
+    def __init__(self, output_size: int):
+        super().__init__()
+        self._output_size = output_size
+
+    def forward(self, inputs: torch.Tensor):
+        # returns a zero vector (and an empty info dict), so tokens routed here
+        # contribute nothing to the aggregated output
+        out = torch.zeros((self._output_size,), device=inputs.device)
+        return out, {}
+
+
+class MaskedMoE(MoE):
+    def __init__(self, config, mlp):
+        super().__init__(config, mlp)
+        self._sequence_length = config.sequence_length
+        # extra "dummy" expert that is always available
+        self.experts.append(DummyExpert(config.n_embd))
+        self.router = nn.Linear(config.n_embd, config.moe_num_experts + 1, bias=False)
+
+    def forward(self, inputs: torch.Tensor, mask: torch.Tensor):
+        inputs_squashed = inputs.view(-1, inputs.shape[-1])
+        router_logits = self.router(inputs_squashed)
+        # append a column of ones for the always-available dummy expert
+        mask = torch.cat(
+            (mask, torch.ones((mask.shape[0], 1), device=mask.device)),
+            dim=1,
+        )
+        # expand the per-sample mask to one row per token
+        mask = mask.repeat_interleave(self._sequence_length, dim=0)
+        router_logits = router_logits * mask
+
+        # note that selected experts will be the same for all orders:
+        # softmax doesn't change the top-k, but the weights are different
+        if self.softmax_order == "softmax_topk":
+            all_probs = F.softmax(router_logits, dim=1)
+            weights, selected_experts = torch.topk(all_probs, self.top_k)
+        elif self.softmax_order == "topk_softmax":
+            weights, selected_experts = torch.topk(router_logits, self.top_k)
+            weights = F.softmax(weights, dim=-1)
+        else:
+            raise ValueError(f"Unknown softmax_order: {self.softmax_order}")
+
+        results = torch.zeros_like(inputs_squashed)
+        # naive looping over experts
+        for i, expert in enumerate(self.experts):
+            batch_idx, nth_expert = torch.where(selected_experts == i)
+            output, _ = expert(inputs_squashed[batch_idx])
+            results[batch_idx] += weights[batch_idx, nth_expert, None] * output
+
+        # return results and router logits (for aux loss calculation later)
+        return results.view_as(inputs), {
+            "router_logits": router_logits,
+            "selected_experts": selected_experts,
+        }
+
+
+class TimeDependantMoE(nn.Module):
+    def __init__(self, config, mlp):
+        super().__init__()
+        self._num_experts = config.moe_num_experts
+        self._mask_moe = MaskedMoE(config, mlp)
+
+    def forward(self, x, date):
+        # expert i is available for a sample only if i < date for that sample
+        range_tensor = torch.arange(self._num_experts).unsqueeze(0).to(x.device)
+        mask_date = (range_tensor < date.unsqueeze(1)).float()
+        return self._mask_moe(x, mask_date)
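
A minimal sketch, not part of this commit, of exercising these blocks in isolation. The ToyMLP expert below is a placeholder for the real MLP from the modeling file; it only mimics the (output, aux_dict) return convention that MoE.forward expects, and the small config values are chosen purely so the code runs quickly.

# Hypothetical sketch: route tokens through MoE and TimeDependantMoE with a toy expert.
import torch
import torch.nn as nn

from configuration import MoEGPTConfig
from moe import MoE, TimeDependantMoE


class ToyMLP(nn.Module):
    """Stand-in expert returning (output, aux_dict) like the real MLP."""

    def __init__(self, config):
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

    def forward(self, x):
        return self.fc(x), {}


config = MoEGPTConfig(
    n_embd=32, sequence_length=8, moe=True, moe_num_experts=4, moe_num_experts_per_tok=2
)
x = torch.randn(2, config.sequence_length, config.n_embd)  # [batch, sequence, n_embd]

moe = MoE(config, ToyMLP)
out, aux = moe(x)
print(out.shape, aux["router_logits"].shape, aux["selected_experts"].shape)

# TimeDependantMoE builds a per-sample availability mask: the router logits of
# experts with index >= date are zeroed before top-k selection.
tmoe = TimeDependantMoE(config, ToyMLP)
date = torch.tensor([2, 4])  # sample 0 zeroes logits of experts 2-3, sample 1 zeroes none
out_t, _ = tmoe(x, date)
print(out_t.shape)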