import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothSoftmaxCEV1(nn.Module):
    """Softmax cross-entropy with label smoothing: the target distribution puts
    1 - lb_smooth on the true class and lb_smooth / num_classes on the others."""

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, logits, label):
        # logits: (N, C, ...) raw scores, label: (N, ...) class indices
        logits = logits.float()  # use fp32 to avoid nan
        with torch.no_grad():
            num_classes = logits.size(1)
            label = label.clone().detach()
            ignore = label.eq(self.lb_ignore)
            n_valid = ignore.eq(0).sum()
            label[ignore] = 0
            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
        logs = self.log_softmax(logits)
        loss = -torch.sum(logs * lb_one_hot, dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            loss = loss.sum() / n_valid
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
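
# Usage sketch (illustration only, not part of the original module): the
# criterion behaves like nn.CrossEntropyLoss but with the smoothed targets
# above; labels equal to ignore_index contribute zero loss.
def _example_label_smooth():
    criteria = LabelSmoothSoftmaxCEV1(lb_smooth=0.1, ignore_index=255)
    logits = torch.randn(4, 10, requires_grad=True)
    labels = torch.randint(0, 10, (4,))
    loss = criteria(logits, labels)
    loss.backward()
    return loss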

def taylor_softmax_v1(x, dim=1, n=4, use_log=False):
    # Taylor softmax: exp(x) is replaced by its order-n Taylor expansion
    # 1 + x + x^2/2! + ... + x^n/n!, which stays positive when n is even.
    assert n % 2 == 0 and n > 0
    fn = torch.ones_like(x)
    denor = 1.
    for i in range(1, n + 1):
        denor *= i  # running factorial i!
        fn = fn + x.pow(i) / denor
    out = fn / fn.sum(dim=dim, keepdim=True)
    if use_log:
        out = out.log()
    return out
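
# Quick sanity check (illustrative sketch, not part of the original module):
# the result is a valid probability distribution along `dim`.
def _example_taylor_softmax():
    x = torch.randn(2, 5)
    probs = taylor_softmax_v1(x, dim=1, n=4)
    assert torch.allclose(probs.sum(dim=1), torch.ones(2))
    return probs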

class LogTaylorSoftmaxV1(nn.Module):

    def __init__(self, dim=1, n=2):
        super(LogTaylorSoftmaxV1, self).__init__()
        assert n % 2 == 0
        self.dim = dim
        self.n = n

    def forward(self, x):
        return taylor_softmax_v1(x, self.dim, self.n, use_log=True)

class TaylorCrossEntropyLossV1(nn.Module):
    """Cross-entropy computed with the log-Taylor-softmax in place of the
    ordinary log-softmax."""

    def __init__(self, n=2, ignore_index=-1, reduction='mean'):
        super(TaylorCrossEntropyLossV1, self).__init__()
        assert n % 2 == 0
        self.taylor_softmax = LogTaylorSoftmaxV1(dim=1, n=n)
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        log_probs = self.taylor_softmax(logits)
        loss = F.nll_loss(log_probs, labels, reduction=self.reduction,
                          ignore_index=self.ignore_index)
        return loss
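
# Usage sketch (illustration only, not part of the original module): a
# drop-in replacement for nn.CrossEntropyLoss on (N, C) logits.
def _example_taylor_ce():
    criteria = TaylorCrossEntropyLossV1(n=2)
    logits = torch.randn(4, 10, requires_grad=True)
    labels = torch.randint(0, 10, (4,))
    loss = criteria(logits, labels)
    loss.backward()
    return loss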

class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by (1 - pt)**gamma, with optional
    per-class alpha weighting."""

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.detach().exp()  # probability of the true class, detached

        if self.alpha is not None:
            if self.alpha.device != input.device or self.alpha.dtype != input.dtype:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()
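
# Usage sketch (illustration only, not part of the original module): with a
# scalar alpha the class weights are [alpha, 1 - alpha], so this form is
# meant for binary (two-class) logits.
def _example_focal_loss():
    criteria = FocalLoss(gamma=2, alpha=0.25)
    logits = torch.randn(4, 2, requires_grad=True)
    labels = torch.randint(0, 2, (4,))
    loss = criteria(logits, labels)
    loss.backward()
    return loss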