from typing import List, Union

import torch
from torch.nn.functional import cross_entropy

from .constants import IGNORE_INDEX

__all__ = ["soft_cross_entropy"]

def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
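    """Cross-entropy with Gaussian label smoothing over selected "soft" tokens.

    Ordinary targets contribute a standard one-hot cross-entropy term. Targets
    whose ID appears in ``soft_tokens`` are instead trained against a soft
    distribution over the soft-token IDs, weighted by a Gaussian in token-ID
    space: ``q(v)`` proportional to ``exp(-(t - v)**2 / (2 * std**2))`` for
    target ID ``t`` and each soft-token ID ``v``, normalized to sum to 1. This
    presumes the soft-token IDs are ordered so that nearby IDs encode nearby
    values. The usual next-token shift (drop the last output position, drop
    the first target position) is applied internally.

    Args:
        outputs: Logits of shape ``(..., seq_len, vocab_size)``.
        targets: Token IDs of shape ``(..., seq_len)``; positions equal to
            ``ignore_index`` are excluded from the loss.
        soft_tokens: IDs of the tokens to treat as soft targets.
        std: Standard deviation of the Gaussian, in token-ID units.
        ignore_index: Target value that marks positions to skip.

    Returns:
        Scalar loss, averaged over all non-ignored target positions.
    """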
    # Remove last token from outputs and first token from targets
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten outputs and targets
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Remove outputs and targets with ignore_index
    indices = targets != ignore_index
    outputs = outputs[indices]
    targets = targets[indices]

    # Convert soft token IDs to tensor
    if isinstance(soft_tokens, list):
        soft_tokens = torch.tensor(soft_tokens).to(targets)

    # Calculate loss for non-soft tokens
    indices = torch.isin(targets, soft_tokens, invert=True)
    loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")

    # Calculate loss for soft tokens
    indices = torch.isin(targets, soft_tokens)
    targets_indices = torch.zeros_like(outputs[indices])
    for k, target in enumerate(targets[indices]):
        dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
        targets_indices[k][soft_tokens] = dist / dist.sum()
    loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")

    # Return average loss
    return loss / targets.size(0)
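

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module). The
# shapes, vocabulary size, and soft-token IDs below are arbitrary assumptions
# for demonstration only. Because of the relative import above, run this as a
# module inside its package (python -m <package>.<this_module>) rather than
# as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 8, 32  # hypothetical sizes
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[:, :2] = -100  # mask a prompt prefix, as ignore_index would

    # Treat IDs 10..19 as consecutive "soft" tokens: a target in this range is
    # trained against a Gaussian over neighboring soft-token IDs instead of a
    # one-hot label.
    loss = soft_cross_entropy(
        logits,
        labels,
        soft_tokens=list(range(10, 20)),
        std=1.0,
        ignore_index=-100,
    )
    print(loss.item())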