Source code for deepke.relation_extraction.few_shot.lit_models.transformer

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
# Hide lines below until Lab 5
import numpy as np
# Hide lines above until Lab 5

from .base import BaseLitModel
from .util import dialog_f1_eval, f1_eval, compute_f1, acc, f1_score
from transformers.optimization import get_linear_schedule_with_warmup

from functools import partial


def mask_hook(grad_input, st, ed):
    mask = torch.zeros((grad_input.shape[0], 1)).type_as(grad_input)
    mask[st: ed] += 1.0  # only optimize the tokens with ids 1~8
    # for the speaker unused token12
    mask[1:3] += 1.0
    return grad_input * mask

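# Minimal usage sketch (not part of the original module), assuming the hook is meant to
# restrict gradient updates to a slice of the embedding matrix; the st/ed values and the
# embedding handle below are illustrative.
#
#   embeddings = model.get_input_embeddings()
#   handle = embeddings.weight.register_hook(lambda grad: mask_hook(grad, st=1, ed=9))
#   ...  # forward/backward as usual; rows outside the mask receive zero gradient
#   handle.remove()
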
def multilabel_categorical_crossentropy(y_pred, y_true):
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()

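# Hedged usage sketch: y_true is a multi-hot {0, 1} tensor with the same shape as the raw
# logits y_pred (no sigmoid applied); a class whose logit is above 0 counts as a predicted
# positive. The toy values below are illustrative only.
#
#   y_pred = torch.tensor([[2.0, -1.0, 0.5]])
#   y_true = torch.tensor([[1.0, 0.0, 1.0]])
#   loss = multilabel_categorical_crossentropy(y_pred, y_true)  # scalar tensor
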
class BertLitModel(BaseLitModel):
    """
    Use AutoModelForMaskedLM and select the output by another layer in the lit model.
    """
    def __init__(self, model, device, args, tokenizer):
        super().__init__(model, device, args)
        self.tokenizer = tokenizer

        with open(f"{args.data_dir}/rel2id.json", "r") as file:
            rel2id = json.load(file)

        Na_num = 0
        for k, v in rel2id.items():
            if k == "NA" or k == "no_relation" or k == "Other":
                Na_num = v
                break
        num_relation = len(rel2id)
        # init loss function
        self.loss_fn = multilabel_categorical_crossentropy if "dialogue" in args.data_dir else nn.CrossEntropyLoss()
        # ignore the no_relation class to compute the f1 score
        self.eval_fn = f1_eval if "dialogue" in args.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)
        self.best_f1 = 0
        self.t_lambda = args.t_lambda

        self.label_st_id = tokenizer("[class1]", add_special_tokens=False)['input_ids'][0]

        self._init_label_word()

    def _init_label_word(self):
        args = self.args
        # ./dataset/dataset_name
        dataset_name = args.data_dir.split("/")[1]
        model_name_or_path = args.model_name_or_path.split("/")[-1]
        label_path = f"data/{model_name_or_path}.pt"
        # [num_labels, num_tokens], ignore the unanswerable
        if "dialogue" in args.data_dir:
            label_word_idx = torch.load(label_path)[:-1]
        else:
            label_word_idx = torch.load(label_path)
        num_labels = len(label_word_idx)

        self.cur_model.resize_token_embeddings(len(self.tokenizer))
        with torch.no_grad():
            word_embeddings = self.cur_model.get_input_embeddings()
            continous_label_word = [a[0] for a in self.tokenizer([f"[class{i}]" for i in range(1, num_labels + 1)], add_special_tokens=False)['input_ids']]
            for i, idx in enumerate(label_word_idx):
                word_embeddings.weight[continous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)
                # word_embeddings.weight[continous_label_word[i]] = self.relation_embedding[i]

            so_word = [a[0] for a in self.tokenizer(["[obj]", "[sub]"], add_special_tokens=False)['input_ids']]
            meaning_word = [a[0] for a in self.tokenizer(["person", "organization", "location", "date", "country"], add_special_tokens=False)['input_ids']]

            for i, idx in enumerate(so_word):
                word_embeddings.weight[so_word[i]] = torch.mean(word_embeddings.weight[meaning_word], dim=0)

            assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)
            assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)

        self.word2label = continous_label_word  # a continuous list of [class*] token ids

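    # Hedged construction sketch (model name, special-token setup, and args fields are
    # illustrative assumptions, not the project's actual CLI): BertLitModel expects
    # args.data_dir to contain rel2id.json, data/<model_name>.pt to hold the per-relation
    # label-word indices consumed by _init_label_word, and a tokenizer that already knows
    # the [class*], [sub] and [obj] tokens.
    #
    #   from transformers import AutoTokenizer, AutoModelForMaskedLM
    #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   tokenizer.add_special_tokens({"additional_special_tokens":
    #       [f"[class{i}]" for i in range(1, num_labels + 1)] + ["[sub]", "[obj]"]})
    #   model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    #   lit_model = BertLitModel(model, device, args, tokenizer)
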
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, so = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        so = so.to(self.device)
        result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)
        logits = result.logits
        output_embedding = result.hidden_states[-1]
        logits = self.pvp(logits, input_ids)
        loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)
        # print("Train/loss: ", loss)
        return loss

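    # Training objective as implemented above, written out for reference:
    #   total_loss = loss_fn(pvp(logits), labels)
    #                + t_lambda * ke_loss(hidden_states[-1], labels, so)
    # where t_lambda defaults to 0.01 (see add_to_argparse below).
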
    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, _ = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
        logits = self.pvp(logits, input_ids)
        loss = self.loss_fn(logits, labels)
        # print("Eval/loss: ", loss)
        return {"loss": loss, "eval_logits": logits.detach().cpu().numpy(), "eval_labels": labels.detach().cpu().numpy()}

    def validation_epoch_end(self, outputs):
        logits = np.concatenate([o["eval_logits"] for o in outputs])
        labels = np.concatenate([o["eval_labels"] for o in outputs])

        f1 = self.eval_fn(logits, labels)['f1']
        # print("Eval/f1: ", f1)
        best_f1 = -1
        if f1 > self.best_f1:
            self.best_f1 = f1
            best_f1 = self.best_f1
        # print("Eval/best_f1: ", self.best_f1)
        return f1, best_f1, self.best_f1

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        input_ids, attention_mask, token_type_ids, labels, _ = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
        logits = self.pvp(logits, input_ids)
        return {"test_logits": logits.detach().cpu().numpy(), "test_labels": labels.detach().cpu().numpy()}

    def test_epoch_end(self, outputs):
        logits = np.concatenate([o["test_logits"] for o in outputs])
        labels = np.concatenate([o["test_labels"] for o in outputs])

        f1 = self.eval_fn(logits, labels)['f1']
        # print("Test/f1: ", f1)
        return f1

    @staticmethod
    def add_to_argparse(parser):
        BaseLitModel.add_to_argparse(parser)
        parser.add_argument("--t_lambda", type=float, default=0.01, help="")  # weight of ke_loss in training_step
        return parser

    def pvp(self, logits, input_ids):
        # convert the [batch_size, seq_len, vocab_size] => [batch_size, num_labels]
        #! hard coded: 103 is the [MASK] token id of the BERT tokenizer
        _, mask_idx = (input_ids == 103).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        mask_output = logits[torch.arange(bs), mask_idx]
        assert mask_idx.shape[0] == bs, "only one mask in sequence!"
        final_output = mask_output[:, self.word2label]
        return final_output

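    # Shape sketch for pvp (numbers are illustrative): from MLM logits of shape
    # [batch_size, seq_len, vocab_size], take the row at each sequence's single [MASK]
    # position, then keep only the vocabulary columns of the [class*] tokens.
    #
    #   [8, 256, 30522]  --select [MASK] position-->  [8, 30522]
    #                    --index self.word2label--->  [8, num_labels]
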
    def ke_loss(self, logits, labels, so):
        subject_embedding = []
        object_embedding = []
        bsz = logits.shape[0]
        for i in range(bsz):
            subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))
            object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))

        subject_embedding = torch.stack(subject_embedding)
        object_embedding = torch.stack(object_embedding)
        # trick: the relation ids are concatenated after the original vocabulary,
        # so labels + self.label_st_id indexes the [class*] token of the gold relation
        relation_embedding = self.cur_model.get_output_embeddings().weight[labels + self.label_st_id]

        loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)

        return loss

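    # Reading of ke_loss (a TransE-style constraint, as suggested by the code above):
    # s and o are mean-pooled last hidden states over the subject/object spans encoded in
    # so = [sub_start, sub_end, obj_start, obj_end] (span layout inferred from the
    # indexing), r is the output embedding of the gold relation's [class*] token, and the
    # penalty is || s + r - o ||_2 over the batch.
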
    def configure_optimizers(self):
        no_decay_param = ["bias", "LayerNorm.weight"]

        if not self.args.two_steps:
            parameters = self.cur_model.named_parameters()
        else:
            # cur_model.bert.embeddings.weight
            parameters = [next(self.cur_model.named_parameters())]  # only optimize the embedding parameters

        optimizer_group_parameters = [
            {"params": [p for n, p in parameters if not any(nd in n for nd in no_decay_param)], "weight_decay": self.args.weight_decay},
            {"params": [p for n, p in parameters if any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        return optimizer
        '''return {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',  # or 'epoch'
                'frequency': 1,
            }
        }'''
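
        # The string literal above references a `scheduler` that this module never builds.
        # A hedged sketch of how one could be created with the imported
        # get_linear_schedule_with_warmup (the step counts are illustrative):
        #
        #   scheduler = get_linear_schedule_with_warmup(
        #       optimizer,
        #       num_warmup_steps=int(0.1 * num_training_steps),
        #       num_training_steps=num_training_steps,
        #   )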