import inspect
import math

import tiktoken
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torch.nn import functional as F
from transformers import PreTrainedModel

from .aux_losses import (
    entropy_reg,
    load_balancing_loss,
    router_z_loss,
)
from .configuration import MoEGPTConfig
from .moe import (
    MaskedMoE,
    MoE,
    TimeDependantMoE,
)


class Output:
    def __init__(self, logits, loss=None, aux_losses=None, router_logits=None):
        self.logits = logits
        self.loss = loss
        self.aux_losses = aux_losses
        self.router_logits = router_logits

    def __repr__(self):
        return (
            f"Output(logits={self.logits}, loss={self.loss}, "
            f"aux_losses={self.aux_losses}, router_logits={self.router_logits})"
        )


class LayerNorm(nn.Module):
    """LayerNorm with an optional bias. PyTorch's nn.LayerNorm does not simply support bias=False."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, in a single batched linear
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        # use flash attention when available (PyTorch >= 2.0)
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        if not self.flash:
            print(
                "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
            )

        # causal mask so attention only looks at positions at or before the current
        # one (only needed by the slow attention path)
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones(config.sequence_length, config.sequence_length)
            ).view(1, 1, config.sequence_length, config.sequence_length),
        )

    def forward(self, x):
        # batch size, sequence length, embedding dimensionality (n_embd)
        B, T, C = x.size()

        # compute query, key, value for all heads in one batched projection
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        if self.flash:
            # efficient attention using fused scaled_dot_product_attention kernels;
            # only apply dropout while training
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # manual implementation of causal attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim_exp_factor = int(config.mlp_dim_exp_factor * 4)

        self.c_fc = nn.Linear(
            config.n_embd, self.dim_exp_factor * config.n_embd, bias=config.bias
        )
        self.c_proj = nn.Linear(
            self.dim_exp_factor * config.n_embd, config.n_embd, bias=config.bias
        )
        self.dropout = nn.Dropout(config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        # return an empty dict so the call signature matches the MoE variants,
        # which also return their router logits and selected experts
        return x, {}


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.moe_config = config.moe_routing
        if config.moe:
            if config.moe_routing == "standard_gating":
                self.mlp = MoE(config, MLP)
            elif config.moe_routing == "masked":
                self.mlp = TimeDependantMoE(config, MLP)
            else:
                raise ValueError(f"Unknown routing: {config.moe_routing}")
        else:
            self.mlp = MLP(config)

    def forward(self, x, date, *args, **kwargs):
        x = x + self.attn(self.ln_1(x))
        if self.moe_config == "masked":
            # the time-dependent MoE needs the date to build its routing mask
            x_, logits_and_experts = self.mlp(self.ln_2(x), date)
        else:
            x_, logits_and_experts = self.mlp(self.ln_2(x))
        x = x + x_
        return x, logits_and_experts


class MoEGPTForCausalLM(PreTrainedModel):
    config_class = MoEGPTConfig

    def __init__(self, config):
        super().__init__(config)
        assert config.vocab_size is not None
        assert config.sequence_length is not None
        self.config = config
        self.tokenizer = tiktoken.get_encoding("gpt2")

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                wpe=nn.Embedding(config.sequence_length, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        )

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight tying: the token embedding and the output projection share weights
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)

        # apply a special scaled init to the residual projections, per GPT-2
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )
            if pn.endswith("router.weight"):
                # normalize the router weights so routing mass starts out even
                # across experts, then rescale to restore the original std
                with torch.no_grad():
                    dim = 1 if config.moe_routing == "standard_gating" else 0
                    std = p.std()
                    p.div_(p.sum(dim=dim, keepdim=True))
                    p.mul_(std / p.std())

    def get_router_losses(self, logits, selected_experts, eval=False):
        # during evaluation, return all router losses for monitoring
        if eval:
            return {
                "moe_entropy_loss": entropy_reg(logits),
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        if self.config.moe_router_loss == "entropy":
            return {
                "moe_entropy_loss": entropy_reg(logits),
            }
        elif self.config.moe_router_loss == "load_balancing_only":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
            }
        elif self.config.moe_router_loss == "load_balancing_z_loss":
            return {
                "moe_aux_loss": load_balancing_loss(logits, selected_experts),
                "moe_z_loss": router_z_loss(logits),
            }
        return {}
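
    # Each key returned above is expected to have a matching "<key>_factor"
    # attribute on the config (e.g. moe_aux_loss_factor, moe_z_loss_factor,
    # moe_entropy_loss_factor); `forward` scales the corresponding loss by that
    # factor and averages it over layers. A minimal sketch of the scaling,
    # assuming a hypothetical config with moe_aux_loss_factor=0.01 and n_layer=12:
    #
    #     loss += aux_losses["moe_aux_loss"] * 0.01 / 12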

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
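
    # A quick sanity check of the two counts (illustrative only; `model` is a
    # hypothetical MoEGPTForCausalLM instance, not defined in this module):
    #
    #     total = model.get_num_params(non_embedding=False)
    #     non_emb = model.get_num_params()  # subtracts only the position embeddings
    #     assert total - non_emb == model.transformer.wpe.weight.numel()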

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, idx, date=None, targets=None, get_logits=True, moe=False):
        device = idx.device
        b, t = idx.size()
        assert (
            t <= self.config.sequence_length
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"

        # default date value when none is provided
        if date is None:
            date = torch.full((b,), 6, dtype=torch.long, device=device)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)

        # router outputs collected from each MoE block
        router_logits = []
        experts = []

        for block in self.transformer.h:
            x, logits_and_experts = block(x, date)
            if len(logits_and_experts) > 0:
                router_logits.append(logits_and_experts["router_logits"])
                experts.append(logits_and_experts["selected_experts"])
        x = self.transformer.ln_f(x)

        aux_losses = {}

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
            if moe and self.config.moe_routing in ("standard_gating", "masked"):
                # accumulate the router losses per layer and, at training time,
                # fold them into the main loss scaled by their config factors
                for logit, expert_choice in zip(router_logits, experts):
                    router_losses = self.get_router_losses(
                        logit, expert_choice, eval=not self.training
                    )
                    for k, v in router_losses.items():
                        aux_losses[k] = aux_losses.get(k, 0.0) + v
                        if self.training:
                            loss += (
                                v
                                * getattr(self.config, k + "_factor")
                                / self.config.n_layer
                            )
        else:
            # no targets given: only compute the logits, skip the loss
            logits = self.lm_head(x)
            loss = None
        logits = logits if get_logits else None
        router_logits = (
            torch.stack(router_logits, dim=0) if len(router_logits) > 0 else None
        )

        return Output(
            logits=logits,
            loss=loss,
            aux_losses=aux_losses,
            router_logits=router_logits,
        )
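
    # A minimal training-step sketch using the Output above (illustrative only;
    # `model`, `optimizer`, `idx`, and `targets` are hypothetical and not defined
    # in this module). With moe=True and an MoE routing mode configured, the
    # router losses are already folded into `out.loss` during training, and
    # `out.aux_losses` remains available for logging:
    #
    #     out = model(idx, targets=targets, moe=True)
    #     optimizer.zero_grad(set_to_none=True)
    #     out.loss.backward()
    #     optimizer.step()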

    def crop_sequence_length(self, sequence_length):
        # model surgery to decrease the sequence length if necessary:
        # shrink the position embedding table and the causal masks in place
        assert sequence_length <= self.config.sequence_length
        self.config.sequence_length = sequence_length
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:sequence_length]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :sequence_length, :sequence_length]

    def get_parameter_group_specs(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        we are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We then return the parameter group specs (parameter names only) from which an optimizer can be built.
        """

        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (
            torch.nn.LayerNorm,
            LayerNorm,
            torch.nn.Embedding,
        )

        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full parameter name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelisted modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklisted modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: because of weight tying, "lm_head.weight" and "transformer.wte.weight"
        # are the same tensor and named_parameters() only reports the latter; drop the
        # lm_head entry so every remaining name maps to an actual parameter
        decay.remove("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert (
            len(inter_params) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - union_params),
        )

        return [
            {"params": sorted(list(decay))},
            {"params": sorted(list(no_decay)), "weight_decay": 0.0},
        ]
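
    # A minimal sketch of turning these specs into an optimizer (illustrative only;
    # the name lookup and the AdamW hyperparameters are assumptions, not part of
    # this module):
    #
    #     param_dict = {pn: p for pn, p in model.named_parameters()}
    #     groups = model.get_parameter_group_specs()
    #     for group in groups:
    #         group["params"] = [param_dict[name] for name in group["params"]]
    #     optimizer = torch.optim.AdamW(groups, lr=3e-4, weight_decay=0.1)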

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        idx = input_ids
        for _ in range(max_new_tokens):
            # if the context is growing too long we must crop it at sequence_length
            idx_cond = (
                idx
                if idx.size(1) <= self.config.sequence_length
                else idx[:, -self.config.sequence_length :]
            )
            # forward the model to get the logits for the last position
            logits = self(idx_cond, date, get_logits=True).logits
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
        idx = (
            torch.tensor(
                self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
            )
            .view(1, -1)
            .to(self.lm_head.weight.device)
        )
        out_idx = (
            self.generate(idx, max_new_tokens, date, temperature, top_k)
            .view(-1)
            .to("cpu")
            .numpy()
        )
        return self.tokenizer.decode(out_idx)
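

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API. It assumes that
    # MoEGPTConfig accepts the keyword arguments below (the hyperparameter values
    # are illustrative) and that a dense model (moe=False) needs no extra MoE
    # config fields.
    config = MoEGPTConfig(
        vocab_size=50257,  # GPT-2 BPE vocabulary, matching the tiktoken encoder
        sequence_length=64,
        n_embd=128,
        n_head=4,
        n_layer=2,
        dropout=0.0,
        bias=False,
        mlp_dim_exp_factor=1.0,
        moe=False,
        moe_routing="standard_gating",
    )
    model = MoEGPTForCausalLM(config).eval()
    # with untrained weights the continuation is gibberish; this only checks the wiring
    print(model.generate_from_string("Hello world", max_new_tokens=8, temperature=1.0, top_k=50))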