import torch
import torch.nn as nn
from opt_einsum import contract
from .losses import ATLoss
import torch.nn.functional as F
from .module import AttentionUNet
from .prepro import process_long_input
class DocREModel(nn.Module):
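    """Document-level relation extraction model.

    A pretrained transformer encodes the document; entity-pair context
    features are laid out as an ne x ne feature map and refined by an
    AttentionUNet (segmentation-style) before a grouped bilinear
    classifier scores each (head, tail) pair under ATLoss.
    """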
def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):
super().__init__()
self.config = config
self.bert_model = model
self.hidden_size = config.hidden_size
self.loss_fnt = ATLoss()
self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
self.emb_size = emb_size
self.block_size = block_size
self.num_labels = num_labels
self.bertdrop = nn.Dropout(0.6)
self.unet_in_dim = args.unet_in_dim
        self.unet_out_dim = args.unet_out_dim
self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)
self.min_height = args.max_height
self.channel_type = args.channel_type
self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
class_number=args.unet_out_dim,
down_channel=args.down_dim)
    def encode(self, input_ids, attention_mask, entity_pos):
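        """Encode the document with the transformer, delegating inputs
        longer than the encoder window to process_long_input. BERT wraps
        chunks with [CLS]/[SEP]; RoBERTa duplicates the end token.
        Returns token embeddings and token-level attentions."""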
config = self.config
if config.transformer_type == "bert":
start_tokens = [config.cls_token_id]
end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        else:
            raise ValueError(f"unsupported transformer_type: {config.transformer_type}")
sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
return sequence_output, attention
    def get_hrt(self, sequence_output, attention, entity_pos, hts):
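        """Pool mention-level representations into entity representations.

        Multi-mention entities use logsumexp pooling over mention
        embeddings and a mean over mention attentions; mentions truncated
        by the max sequence length fall back to zeros. Returns head/tail
        embeddings for every pair in `hts`, plus per-document entity
        embeddings and attentions (attentions padded to min_height).
        """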
offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
bs, h, _, c = attention.size()
        # ne = max([len(x) for x in entity_pos])  # max number of entities in this batch
hss, tss, rss = [], [], []
entity_es = []
entity_as = []
for i in range(len(entity_pos)):
entity_embs, entity_atts = [], []
for entity_num, e in enumerate(entity_pos[i]):
if len(e) > 1:
e_emb, e_att = [], []
for start, end in e:
if start + offset < c:
# In case the entity mention is truncated due to limited max seq length.
e_emb.append(sequence_output[i, start + offset])
e_att.append(attention[i, :, start + offset])
if len(e_emb) > 0:
e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
e_att = torch.stack(e_att, dim=0).mean(0)
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
else:
start, end = e[0]
if start + offset < c:
e_emb = sequence_output[i, start + offset]
e_att = attention[i, :, start + offset]
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
entity_embs.append(e_emb)
entity_atts.append(e_att)
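            # Pad the attention list up to min_height by repeating the last
            # entity's attention, so every document yields a fixed-size
            # min_height x seq_len stack for the channel map below.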
            for _ in range(self.min_height - entity_num - 1):
entity_atts.append(e_att)
entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]
entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]
entity_es.append(entity_embs)
entity_as.append(entity_atts)
ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
hss.append(hs)
tss.append(ts)
hss = torch.cat(hss, dim=0)
tss = torch.cat(tss, dim=0)
return hss, tss, entity_es, entity_as
    def get_mask(self, ents, bs, ne, run_device):
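        """Build entity and relation masks: ent_mask marks valid entity
        slots per document, rel_mask marks valid (head, tail) cells in
        the ne x ne relation grid."""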
ent_mask = torch.zeros(bs, ne, device=run_device)
rel_mask = torch.zeros(bs, ne, ne, device=run_device)
for _b in range(bs):
ent_mask[_b, :len(ents[_b])] = 1
rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
return ent_mask, rel_mask
    def get_ht(self, rel_enco, hts):
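        """Gather the refined pair representation rel_enco[i, h, t] for
        each (head, tail) index pair in hts, concatenated over the batch."""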
htss = []
for i in range(len(hts)):
ht_index = hts[i]
for (h_index, t_index) in ht_index:
                htss.append(rel_enco[i, h_index, t_index])
        htss = torch.stack(htss, dim=0)
return htss
    def get_channel_map(self, sequence_output, entity_as):
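        """Build the context-based entity-pair feature map.

        For every (head, tail) pair, the head and tail attention
        distributions are multiplied and renormalized, then contracted
        with the token embeddings to yield a context vector; the results
        are reshaped into a [bs, ne, ne, d] map that serves as the
        U-Net input image.
        """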
        bs, _, d = sequence_output.size()
        # ne = max([len(x) for x in entity_as])  # max number of entities in this batch
ne = self.min_height
index_pair = []
for i in range(ne):
            tmp = torch.cat((torch.ones((ne, 1), dtype=torch.long) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
index_pair.append(tmp)
index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
map_rss = []
for b in range(bs):
entity_atts = entity_as[b]
h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
ht_att = (h_att * t_att).mean(1)
ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
rs = contract("ld,rl->rd", sequence_output[b], ht_att)
map_rss.append(rs)
map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
return map_rss
    def forward(self,
input_ids=None,
attention_mask=None,
labels=None,
entity_pos=None,
hts=None,
instance_mask=None,
):
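        """Score every candidate (head, tail) entity pair.

        Returns predicted labels (via ATLoss.get_label); when gold
        labels are provided, returns (loss, predictions) instead.
        """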
        sequence_output, attention = self.encode(input_ids, attention_mask, entity_pos)
        bs, seq_len, d = sequence_output.shape
        run_device = sequence_output.device
        ne = max([len(x) for x in entity_pos])  # max number of entities in this batch
ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)
# get hs, ts and entity_embs >> entity_rs
hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
if self.channel_type == 'context-based':
feature_map = self.get_channel_map(sequence_output, entity_as)
attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
else:
            raise ValueError("channel_type must be specified correctly")
attn_map = self.segmentation_net(attn_input)
        h_t = self.get_ht(attn_map, hts)
hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))
ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))
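        # Grouped bilinear interaction: split head/tail vectors into
        # block_size-dim chunks and take per-chunk outer products, a
        # low-rank approximation of a full bilinear classifier.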
b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
logits = self.bilinear(bl)
        output = self.loss_fnt.get_label(logits, num_labels=self.num_labels)
if labels is not None:
labels = [torch.tensor(label) for label in labels]
labels = torch.cat(labels, dim=0).to(logits)
loss = self.loss_fnt(logits.float(), labels.float())
output = (loss.to(sequence_output), output)
return output