Source code for deepke.relation_extraction.standard.tools.loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class LabelSmoothSoftmaxCEV1(nn.Module):
    """Softmax cross entropy with label smoothing: the one-hot target is mixed
    with a uniform distribution controlled by ``lb_smooth``."""

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        logits = logits.float()  # use fp32 to avoid nan
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label.eq(self.lb_ignore)
            n_valid = ignore.eq(0).sum()
            label[ignore] = 0  # temporary class index so scatter_ below stays in range
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()

        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * lb_one_hot, dim=1)
        loss[ignore] = 0  # ignored positions contribute nothing
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
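# --- Example (editor's sketch, not part of the original DeepKE module) ---
# A minimal use of LabelSmoothSoftmaxCEV1 on made-up shapes (4 samples, 5
# classes), with one position ignored via the default ignore_index=-100.
# The helper name below is hypothetical and only here for illustration.
def _example_label_smooth_ce():
    logits = torch.randn(4, 5, requires_grad=True)   # hypothetical logits
    labels = torch.tensor([1, 0, 3, -100])           # -100 is ignored
    criterion = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, reduction='mean')
    loss = criterion(logits, labels)                 # averaged over the 3 valid rows
    loss.backward()                                  # gradients flow back to `logits`
    return loss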
def taylor_softmax_v1(x, dim=1, n=4, use_log=False):
    # Approximate exp(x) by its order-n Taylor polynomial
    # 1 + x + x**2/2! + ... + x**n/n! (positive for even n), then normalise.
    assert n % 2 == 0 and n > 0
    fn = torch.ones_like(x)
    denor = 1.
    for i in range(1, n + 1):
        denor *= i
        fn = fn + x.pow(i) / denor
    out = fn / fn.sum(dim=dim, keepdim=True)
    if use_log:
        out = out.log()
    return out
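# --- Example (editor's sketch, not part of the original DeepKE module) ---
# Quick sanity check of taylor_softmax_v1: with n=2 each unnormalised score is
# 1 + x + x**2 / 2, and after normalisation every row sums to 1 like an
# ordinary softmax. The input values below are made up.
def _example_taylor_softmax():
    x = torch.tensor([[1.0, 2.0, 3.0]])
    p = taylor_softmax_v1(x, dim=1, n=2)
    assert torch.allclose(p.sum(dim=1), torch.ones(1))  # rows are proper distributions
    return p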
class LogTaylorSoftmaxV1(nn.Module):

    def __init__(self, dim=1, n=2):
        super(LogTaylorSoftmaxV1, self).__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        return taylor_softmax_v1(x, self.dim, self.n, use_log=True)
class TaylorCrossEntropyLossV1(nn.Module):
    """Cross entropy computed from the log Taylor softmax via ``F.nll_loss``."""

    def __init__(self, n=2, ignore_index=-1, reduction='mean'):
        super(TaylorCrossEntropyLossV1, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = LogTaylorSoftmaxV1(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        log_probs = self.taylor_softmax(logits)
        loss = F.nll_loss(log_probs, labels, reduction=self.reduction,
                          ignore_index=self.ignore_index)
        return loss
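# --- Example (editor's sketch, not part of the original DeepKE module) ---
# TaylorCrossEntropyLossV1 can be called like nn.CrossEntropyLoss on
# (batch, num_classes) logits; the shapes below are made up.
def _example_taylor_ce():
    logits = torch.randn(8, 10, requires_grad=True)  # 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    criterion = TaylorCrossEntropyLossV1(n=2, reduction='mean')
    loss = criterion(logits, labels)
    loss.backward()
    return loss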
class FocalLoss(nn.Module):
    """Focal loss: scales each sample's cross entropy by (1 - pt)**gamma so that
    well-classified examples contribute less; ``alpha`` optionally weights the
    classes."""

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())  # detached probability of the true class

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
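# --- Example (editor's sketch, not part of the original DeepKE module) ---
# FocalLoss on 2-D logits for a binary problem; gamma=2 and alpha=0.25 are the
# values popularised by the focal-loss paper, not values prescribed here.
def _example_focal_loss():
    logits = torch.randn(4, 2, requires_grad=True)   # batch of 4, 2 classes
    targets = torch.tensor([0, 1, 1, 0])
    criterion = FocalLoss(gamma=2, alpha=0.25)       # alpha becomes [0.25, 0.75]
    loss = criterion(logits, targets)
    loss.backward()
    return loss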