import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer
from utils import seq_len_to_mask, to_one_hot
class Capsule(BasicModule):
    """Relation classifier: embedding -> CNN (primary capsules) -> capsule layer.

    ``forward`` returns one score per relation: the L2 norm of that
    relation's output capsule. ``loss`` applies a margin loss to those norms.
    """

    def __init__(self, cfg):
        """Derive capsule settings on ``cfg`` and build the sub-modules.

        NOTE: ``cfg`` is mutated in place. ``in_channels``,
        ``input_dim_capsule`` and ``num_capsule`` are set here before the
        same ``cfg`` object is handed to ``Embedding``, ``CNN`` and
        ``CapsuleLayer`` (presumably those modules read these fields —
        confirm against their implementations).

        Args:
            cfg: config object; reads ``dim_strategy``, ``word_dim``,
                ``pos_dim``, ``out_channels`` and ``num_relations``.
        """
        super().__init__()
        # 'cat' presumably concatenates the word embedding with the head and
        # tail position embeddings, hence word_dim + 2 * pos_dim channels.
        if cfg.dim_strategy == 'cat':
            cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.in_channels = cfg.word_dim
        # Capsule config: the CNN's output channels are the primary-capsule
        # input dimension, and there is one output capsule per relation.
        cfg.input_dim_capsule = cfg.out_channels
        cfg.num_capsule = cfg.num_relations
        self.num_relations = cfg.num_relations
        self.embedding = Embedding(cfg)
        self.cnn = CNN(cfg)
        self.capsule = CapsuleLayer(cfg)

    def forward(self, x):
        """Score every relation for a batch.

        Args:
            x: dict providing ``'word'``, ``'head_pos'`` and ``'tail_pos'``
                (a ``'lens'`` entry may be present but is not needed here).

        Returns:
            Tensor ``[B, N]``: the L2 norm of each output capsule, where the
            number of output capsules is set to ``num_relations`` in
            ``__init__``.
        """
        word, head_pos, tail_pos = x['word'], x['head_pos'], x['tail_pos']
        inputs = self.embedding(word, head_pos, tail_pos)
        # No padding mask is applied: the CNN changes the sequence length, so
        # a length-based mask can no longer be aligned with its output. That
        # is acceptable because primary capsules carry only coarse-grained
        # information anyway.
        primary, _ = self.cnn(inputs)
        output = self.capsule(primary)
        # Reduce each capsule vector to its length (L2 norm) as the score.
        output = output.norm(p=2, dim=-1)
        return output  # [B, N]

    def loss(self, predict, target, reduction='mean'):
        """Margin loss on capsule lengths.

        For relation k with one-hot target T_k and predicted length v_k::

            L_k = T_k * relu(m_plus - v_k) ** 2
                  + lambda * (1 - T_k) * relu(v_k - m_minus) ** 2

        with ``m_plus=0.9``, ``m_minus=0.1``, ``lambda=0.5``. ``L_k`` is
        summed over the last (relation) dimension to give a per-sample loss.

        Args:
            predict: capsule lengths, presumably the ``[B, N]`` output of
                ``forward``.
            target: class labels, converted via
                ``to_one_hot(target, num_relations)`` (presumably integer
                indices of shape ``[B]`` — confirm against ``to_one_hot``).
            reduction: ``'sum'`` sums the per-sample losses, ``'none'``
                returns them unreduced; any other value (default ``'mean'``)
                averages them.

        Returns:
            Scalar tensor, or the per-sample loss tensor when
            ``reduction='none'``.
        """
        m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5
        target = to_one_hot(target, self.num_relations)
        # Penalise the true relation for being shorter than m_plus ...
        max_l = torch.relu(m_plus - predict) ** 2
        # ... and every other relation for being longer than m_minus.
        max_r = torch.relu(predict - m_minus) ** 2
        loss = target * max_l + loss_lambda * (1 - target) * max_r
        loss = torch.sum(loss, dim=-1)
        if reduction == 'sum':
            return loss.sum()
        if reduction == 'none':
            return loss
        # Default ('mean' and any unrecognised value): average over samples.
        return loss.mean()