import torch
import torch.nn as nn
import torch.nn.functional as F


def log_mean(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Computes log(mean(exp(x))) along `dim` in a numerically stable way."""
    return torch.logsumexp(x, dim=dim) - torch.log(
        torch.tensor(x.shape[dim], dtype=torch.float32)
    )
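

# A minimal sketch (not part of the original module) illustrating that
# `log_mean` is the log-space analogue of averaging probabilities: for
# log-probs of shape [num_tokens, num_experts], it matches the (less stable)
# log of the mean of the exponentiated values.
def _log_mean_example() -> None:
    logprobs = F.log_softmax(torch.randn(8, 4), dim=-1)
    stable = log_mean(logprobs, dim=0)
    naive = torch.log(logprobs.exp().mean(dim=0))
    assert torch.allclose(stable, naive, atol=1e-6)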


def entropy_reg(logits: torch.Tensor, mean_over_batch: bool = True):
    """Entropy regularization for the router.

    Returns the negative entropy of the routing distribution derived from
    `logits` ([num_tokens, num_experts]); if `mean_over_batch` is True, the
    distribution is first averaged over the batch in log space.
    """
    # Entropy of a distribution given its log-probabilities l.
    entropy_l = lambda l: -(l * l.exp()).sum(-1)

    logprobs = F.log_softmax(logits, dim=-1)
    if mean_over_batch:
        # Average the per-token routing distributions over the batch (dim 0).
        logprobs = log_mean(logprobs, 0)
    return -entropy_l(logprobs).mean()
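

# A small usage sketch (shapes are assumptions, not from the original module):
# for uniform routing the regularizer reaches its minimum of -log(num_experts),
# so adding it to the training loss pushes routing towards higher entropy.
def _entropy_reg_example() -> None:
    logits = torch.randn(6, 4)
    reg = entropy_reg(logits)
    assert reg.ndim == 0
    uniform = torch.zeros(6, 4)
    assert torch.isclose(entropy_reg(uniform), -torch.log(torch.tensor(4.0)))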


def load_balancing_loss(logits: torch.Tensor, expert_indices: torch.Tensor) -> float:
    """Computes the auxiliary load balancing loss as in Switch Transformer.

    See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
    implements the loss presented in equations (4) - (6) of the paper, which
    penalizes cases where routing between experts is unbalanced.

    Args:
      logits: logits assigned to each expert per token. Shape:
        <float32>[batch_size * sequence_length, num_experts].
      expert_indices: <int>[batch_size * sequence_length, num_selected_experts]
        indices identifying the top num_selected_experts for a given token.

    Returns:
      The auxiliary loss.
    """
    num_tokens, num_experts = logits.shape

    # One-hot mask of the selected experts per token.
    # Shape: [num_tokens, num_selected_experts, num_experts].
    expert_mask = F.one_hot(expert_indices, num_experts)

    # For each token, mark whether it was routed to a given expert at all.
    # Shape: [num_tokens, num_experts].
    expert_mask, _ = torch.max(expert_mask, dim=-2)

    # Fraction of tokens dispatched to each expert (f_i in the paper).
    tokens_per_expert = torch.mean(expert_mask, dim=0, dtype=torch.float32)

    # Mean router probability per expert (P_i in the paper), averaged in log space.
    logprobs = F.log_softmax(logits, dim=-1)
    logprobs = log_mean(logprobs, dim=0)
    router_prob_per_expert = torch.exp(logprobs)

    return (
        torch.mean(
            tokens_per_expert * router_prob_per_expert,
            dtype=torch.float32,
        )
        * num_experts
    )
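

# A minimal usage sketch (shapes and the top-1 selection are assumptions, not
# part of the original module): the loss equals num_experts * mean_i(f_i * P_i),
# which is smallest when tokens and router probability are spread evenly over
# the experts and grows when routing collapses onto a few of them.
def _load_balancing_loss_example() -> None:
    logits = torch.randn(8, 4)
    expert_indices = torch.topk(logits, k=1, dim=-1).indices  # [8, 1]
    aux_loss = load_balancing_loss(logits, expert_indices)
    assert aux_loss.ndim == 0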


def router_z_loss(router_logits: torch.Tensor) -> float:
    """Computes the router z-loss.

    The router z-loss was introduced in Designing Effective Sparse Expert
    Models (https://arxiv.org/abs/2202.08906). It encourages router logits to
    remain small in an effort to improve stability.

    Args:
      router_logits: <float>[batch_size * sequence_length, num_experts]
        router logits.

    Returns:
      Scalar router z-loss.
    """
    num_tokens, _ = router_logits.shape
    # Per-token penalty: squared log-partition function (logsumexp) of the logits.
    log_z = torch.logsumexp(router_logits, dim=-1)
    z_loss = log_z**2
    return torch.sum(z_loss, dtype=torch.float32) / num_tokens
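

# A combined usage sketch (coefficient values are assumptions; the papers cited
# above report small weights, e.g. about 1e-2 for the load balancing loss and
# about 1e-3 for the z-loss): the auxiliary terms are scaled and added to the
# main task loss during training.
def _auxiliary_losses_example() -> None:
    router_logits = torch.randn(8, 4)
    expert_indices = torch.topk(router_logits, k=2, dim=-1).indices
    total_aux = 1e-2 * load_balancing_loss(router_logits, expert_indices)
    total_aux = total_aux + 1e-3 * router_z_loss(router_logits)
    assert total_aux.ndim == 0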