import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .base import BaseLitModel
from .util import dialog_f1_eval, f1_eval, compute_f1, acc, f1_score
from transformers.optimization import get_linear_schedule_with_warmup
from functools import partial
def mask_hook(grad_input, st, ed):
    mask = torch.zeros((grad_input.shape[0], 1)).type_as(grad_input)
    mask[st: ed] += 1.0  # only optimize the tokens with ids in [st, ed)
    # also unmask rows 1 and 2 for the speaker unused tokens
    mask[1:3] += 1.0
    return grad_input * mask
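# A minimal usage sketch (not part of the original module): mask_hook is meant to be
# registered on the gradient of the input embedding weight so that only the virtual
# label-word rows (plus the speaker rows above) receive updates. `label_st`/`label_ed`
# below are hypothetical row bounds for those virtual tokens.
#
#     embeddings = model.get_input_embeddings()
#     embeddings.weight.register_hook(partial(mask_hook, st=label_st, ed=label_ed))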
def multilabel_categorical_crossentropy(y_pred, y_true):
    # flip the sign of the logits of the true classes so positives and negatives
    # can each be aggregated with a single logsumexp against a zero threshold
    y_pred = (1 - 2 * y_true) * y_pred
    # mask positives out of the negative term and negatives out of the positive term
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    # append a zero logit that acts as the decision threshold
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()
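# Worked example (illustrative, not part of the original file): the loss drives the
# logits of true classes above 0 and the rest below 0, so a sample whose positive
# logit is +5 and whose two negative logits are -5 is already close to zero loss:
#
#     y_true = torch.tensor([[1., 0., 0.]])
#     y_pred = torch.tensor([[5., -5., -5.]])
#     multilabel_categorical_crossentropy(y_pred, y_true)  # ~= 0.02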
class BertLitModel(BaseLitModel):
    """
    Use AutoModelForMaskedLM and select the output with another layer in the lit model.
    """
    def __init__(self, model, device, args, tokenizer):
        super().__init__(model, device, args)
        self.tokenizer = tokenizer
        with open(f"{args.data_dir}/rel2id.json", "r") as file:
            rel2id = json.load(file)
Na_num = 0
for k, v in rel2id.items():
if k == "NA" or k == "no_relation" or k == "Other":
Na_num = v
break
num_relation = len(rel2id)
# init loss function
self.loss_fn = multilabel_categorical_crossentropy if "dialogue" in args.data_dir else nn.CrossEntropyLoss()
        # ignore the no_relation class when computing the f1 score
self.eval_fn = f1_eval if "dialogue" in args.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)
self.best_f1 = 0
self.t_lambda = args.t_lambda
self.label_st_id = tokenizer("[class1]", add_special_tokens=False)['input_ids'][0]
self._init_label_word()
def _init_label_word(self):
args = self.args
# ./dataset/dataset_name
dataset_name = args.data_dir.split("/")[1]
model_name_or_path = args.model_name_or_path.split("/")[-1]
label_path = f"data/{model_name_or_path}.pt"
        # label_word_idx: [num_labels, num_tokens]; for dialogue data drop the last (unanswerable) label
if "dialogue" in args.data_dir:
label_word_idx = torch.load(label_path)[:-1]
else:
label_word_idx = torch.load(label_path)
num_labels = len(label_word_idx)
self.cur_model.resize_token_embeddings(len(self.tokenizer))
with torch.no_grad():
word_embeddings = self.cur_model.get_input_embeddings()
continous_label_word = [a[0] for a in self.tokenizer([f"[class{i}]" for i in range(1, num_labels+1)], add_special_tokens=False)['input_ids']]
for i, idx in enumerate(label_word_idx):
word_embeddings.weight[continous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)
# word_embeddings.weight[continous_label_word[i]] = self.relation_embedding[i]
so_word = [a[0] for a in self.tokenizer(["[obj]","[sub]"], add_special_tokens=False)['input_ids']]
meaning_word = [a[0] for a in self.tokenizer(["person","organization", "location", "date", "country"], add_special_tokens=False)['input_ids']]
            for idx in so_word:
                word_embeddings.weight[idx] = torch.mean(word_embeddings.weight[meaning_word], dim=0)
assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)
assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)
        self.word2label = continous_label_word  # a contiguous list of label token ids
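    # Illustrative note (a sketch, not part of the original code): each virtual token
    # [class{i}] starts from the mean embedding of the tokens listed in label_word_idx[i],
    # so a relation verbalised as, e.g., ["person", "title"] would be initialised as
    # mean(emb("person"), emb("title")); pvp() later reads logits only at these token ids.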
    def forward(self, x):
return self.model(x)
    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, so = batch
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
token_type_ids = token_type_ids.to(self.device)
labels = labels.to(self.device)
so = so.to(self.device)
result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)
logits = result.logits
output_embedding = result.hidden_states[-1]
logits = self.pvp(logits, input_ids)
loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)
#print("Train/loss: ", loss)
return loss
    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, _ = batch
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
token_type_ids = token_type_ids.to(self.device)
labels = labels.to(self.device)
logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
logits = self.pvp(logits, input_ids)
loss = self.loss_fn(logits, labels)
#print("Eval/loss: ", loss)
return {"loss": loss, "eval_logits": logits.detach().cpu().numpy(), "eval_labels": labels.detach().cpu().numpy()}
    def validation_epoch_end(self, outputs):
logits = np.concatenate([o["eval_logits"] for o in outputs])
labels = np.concatenate([o["eval_labels"] for o in outputs])
f1 = self.eval_fn(logits, labels)['f1']
#print("Eval/f1: ", f1)
        best_f1 = -1  # -1 signals that this epoch did not improve on self.best_f1
if f1 > self.best_f1:
self.best_f1 = f1
best_f1 = self.best_f1
#print("Eval/best_f1: ", self.best_f1)
return f1, best_f1, self.best_f1
    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, _ = batch
input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device)
token_type_ids = token_type_ids.to(self.device)
labels = labels.to(self.device)
logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
logits = self.pvp(logits, input_ids)
return {"test_logits": logits.detach().cpu().numpy(), "test_labels": labels.detach().cpu().numpy()}
    def test_epoch_end(self, outputs):
logits = np.concatenate([o["test_logits"] for o in outputs])
labels = np.concatenate([o["test_labels"] for o in outputs])
f1 = self.eval_fn(logits, labels)['f1']
#print("Test/f1: ", f1)
return f1
    @staticmethod
def add_to_argparse(parser):
BaseLitModel.add_to_argparse(parser)
parser.add_argument("--t_lambda", type=float, default=0.01, help="")
return parser
    def pvp(self, logits, input_ids):
        # convert [batch_size, seq_len, vocab_size] => [batch_size, num_labels]
        # locate the [MASK] position in each sequence
        _, mask_idx = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        assert mask_idx.shape[0] == bs, "expected exactly one [MASK] token per sequence"
        mask_output = logits[torch.arange(bs), mask_idx]
        final_output = mask_output[:, self.word2label]
        return final_output
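    # Shape walkthrough (a sketch of what pvp does): logits [bs, seq_len, vocab_size]
    #   -> select the [MASK] position per sequence  -> mask_output  [bs, vocab_size]
    #   -> keep only the [class*] token columns     -> final_output [bs, num_labels]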
    def ke_loss(self, logits, labels, so):
subject_embedding = []
object_embedding = []
bsz = logits.shape[0]
for i in range(bsz):
subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))
object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))
subject_embedding = torch.stack(subject_embedding)
object_embedding = torch.stack(object_embedding)
        # trick: the [class*] token ids are contiguous, so the label word id is labels + self.label_st_id
        relation_embedding = self.cur_model.get_output_embeddings().weight[labels + self.label_st_id]
loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)
return loss
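    # The loss above is a TransE-style translation constraint between the averaged
    # subject/object hidden states and the relation (label word) embedding:
    #     L_ke = || h_subject + r_label - h_object ||_2
    # In training_step it is added to the classification loss with weight t_lambda.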