Source code for deepke.name_entity_re.few_shot.utils.util

import torch
import numpy as np
import random
from torch import nn
import torch.nn.functional as F
from transformers import BartModel, BartTokenizer
import os

def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
    """Initialise added ``<<...>>`` tokens with the mean of their sub-token embeddings.

    For every no-split token of ``tokenizer`` that starts with ``<<``, the text
    between the surrounding two-character delimiters is tokenized with a fresh
    tokenizer loaded from ``bart_name``; the embeddings of those sub-tokens are
    averaged and written into the decoder embedding row of the added token.

    Args:
        tokenizer (BartTokenizer): tokenizer that already contains the added
            ``<<...>>`` tokens in ``unique_no_split_tokens``.
        bart_model (BartModel): model whose embedding table is updated in place.
        bart_name: name or path passed to ``BartTokenizer.from_pretrained`` to
            obtain the tokenizer used for splitting the inner text.
        num_tokens: every added token id is asserted to be ``>= num_tokens``
            (presumably the size of the original vocabulary; confirm with caller).

    Raises:
        RuntimeError: if an added token is split into more than one piece by
            ``tokenizer``.

    Returns:
        BartModel: the same ``bart_model`` object, with updated embeddings.
    """
    _tokenizer = BartTokenizer.from_pretrained(bart_name)
    for token in tokenizer.unique_no_split_tokens:
        if token[:2] != '<<':  # only the specially marked tokens are initialised
            continue
        index = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
        if len(index) > 1:
            raise RuntimeError(f"{token} wrong split")
        index = index[0]
        assert index >= num_tokens, (index, num_tokens, token)
        indexes = _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(token[2:-2]))
        # Indexing ``weight.data`` yields a view of the embedding table, so the
        # in-place ``+=`` / ``/=`` below would overwrite the embedding row of the
        # first sub-token. Accumulate into a private copy instead.
        embed = bart_model.encoder.embed_tokens.weight.data[indexes[0]].clone()
        for i in indexes[1:]:
            embed += bart_model.decoder.embed_tokens.weight.data[i]
        embed /= len(indexes)
        bart_model.decoder.embed_tokens.weight.data[index] = embed
    return bart_model
def seq_to_mask(seq_len, max_len):
    """Build a boolean mask from per-sequence lengths.

    Args:
        seq_len (torch.Tensor): 1-D tensor holding the length of each sequence
            in the batch.
        max_len: width of the mask; when falsy, the largest value in
            ``seq_len`` is used instead.

    Returns:
        torch.Tensor: mask with one row per sequence, where entry ``[i, j]`` is
        true exactly when ``j < seq_len[i]``.
    """
    if max_len:
        width = int(max_len)
    else:
        width = seq_len.max().long()
    batch_size = seq_len.size(0)
    # Same position row repeated for each sequence, cast to seq_len's dtype/device.
    positions = torch.arange(width).expand(batch_size, -1).to(seq_len)
    lengths_column = seq_len.unsqueeze(1)
    return positions.lt(lengths_column)
def get_loss(tgt_tokens, tgt_seq_len, pred):
    """Cross-entropy between predictions and the target shifted by one position.

    :param tgt_tokens: bsz x max_len, contains [sos, token, eos]
    :param tgt_seq_len: bsz, length of each target sequence
    :param pred: bsz x max_len-1 x vocab_size
    :return: the loss value returned by ``F.cross_entropy``
    """
    # The first target column is dropped below, so every length shrinks by one.
    shifted_len = tgt_seq_len - 1
    padding = seq_to_mask(shifted_len, max_len=tgt_tokens.size(1) - 1).eq(0)
    # Positions past each sequence's length are set to -100.
    labels = tgt_tokens[:, 1:].masked_fill(padding, -100)
    # cross_entropy expects the class dimension second: bsz x vocab x len.
    logits = pred.transpose(1, 2)
    return F.cross_entropy(input=logits, target=labels)
def get_model_device(model):
    """Return the device of the model's first parameter.

    Args:
        model (nn.Module): the module to inspect.

    Returns:
        The ``device`` of the first parameter, or ``None`` when the module has
        no parameters.
    """
    assert isinstance(model, nn.Module)
    first_param = next(model.parameters(), None)
    return None if first_param is None else first_param.device
def set_seed(seed=2021):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also switches cuDNN to deterministic mode.

    Args:
        seed (int): the seed value applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
def convert_preds_to_outputs(preds, raw_words, mapping, tokenizer):
    """Convert model predictions to BIO tag sequences.

    Each predicted row is read as ``[start, w, ..., w, entity, w, ..., entity,
    ..., 1, ...]``: ids below ``len(mapping) + 2`` are entity ids, all others
    are word pointers. A run of strictly increasing word pointers followed by
    an entity id forms one entity span.

    Args:
        preds (torch.Tensor): predicted ids, one row per sentence. The code
            compares the elements as scalars, so this is treated as a
            bsz x seq_len tensor of ids. The value 1 marks the end of a
            prediction.
        raw_words (List): source words of every sentence.
        mapping (dict): entity labels mapped to their ``<<>>`` tokens; only the
            key order is used here (entity id ``k + 2`` is the k-th key).
        tokenizer: object with ``tokenize(word, add_prefix_space=True)``,
            e.g. a BartTokenizer.

    Returns:
        List: one list per sentence, as long as that sentence's raw words,
        holding ``O`` / ``B-<label>`` / ``I-<label>`` tags.
    """
    id2label = list(mapping.keys())
    # Running count of end markers (id 1) seen from the right-hand side.
    pred_eos_index = preds.flip(dims=[1]).eq(1).cumsum(dim=1).long()
    preds = preds[:, 1:]  # drop the first column (presumably the start token)
    # Number of positions up to and including the first end marker ...
    pred_seq_len = pred_eos_index.flip(dims=[1]).eq(pred_eos_index[:, -1:]).sum(dim=1)  # bsz
    # ... minus the dropped first column and the end marker itself.
    pred_seq_len = (pred_seq_len - 2).tolist()
    word_start_index = len(mapping) + 2

    outputs = []
    for i, pred_item in enumerate(preds.tolist()):
        pred_item = pred_item[:pred_seq_len[i]]  # single sentence prediction
        pairs, cur_pair = [], []
        for idx in pred_item:
            if idx < word_start_index:  # entity id closes the current span
                # Keep the span only if its word pointers strictly increase.
                if cur_pair and all(a < b for a, b in zip(cur_pair, cur_pair[1:])):
                    pairs.append(tuple(cur_pair + [idx]))
                cur_pair = []
            else:  # word pointer
                cur_pair.append(idx)

        raw_words_item = raw_words[i]
        num_words = len(raw_words_item)
        # cum_lens[k] is the sub-token offset at which word k starts (offset 1
        # is the first word); two extra entries follow the last word.
        cum_lens = [1]
        start_idx = 1
        for word in raw_words_item:
            start_idx += len(tokenizer.tokenize(word, add_prefix_space=True))
            cum_lens.append(start_idx)
        cum_lens.append(start_idx + 1)
        # First position of every offset: same result as list.index, but O(1).
        word_at_offset = {}
        for position, offset in enumerate(cum_lens):
            word_at_offset.setdefault(offset, position)

        output = ['O'] * num_words
        for pair in pairs:  # (word id, ..., entity id)
            entity = pair[-1]
            if entity < 2:
                # Ids 0 and 1 are not entity labels; indexing id2label with
                # entity - 2 would wrap around to a wrong label.
                continue
            label = id2label[entity - 2]
            words = []
            for word in pair[:-1]:
                position = word_at_offset.get(word - word_start_index)
                # Pointers resolving to the trailing entries of cum_lens do not
                # correspond to a raw word and would index past `output`.
                if position is not None and position < num_words:
                    words.append(position)
            if not words:
                continue
            first_word, last_word = words[0], words[-1]
            output[first_word] = f'B-{label}'
            for k in range(first_word + 1, last_word + 1):
                output[k] = f'I-{label}'
        outputs.append(output)
    return outputs
def write_predictions(path, texts, labels):
    """Write model predictions to ``path`` in CoNLL format.

    The file starts with a ``-DOCSTART- O`` line and a blank line; every
    sentence is then written as one ``word<TAB>label`` line per word, followed
    by a blank line. An existing file is overwritten.

    Args:
        path (str): save path.
        texts (List): raw texts, one list of words per sentence.
        labels (List): predicted labels, parallel to ``texts``.

    Raises:
        AssertionError: if ``texts`` and ``labels`` differ in length.
        IndexError: if a sentence has fewer labels than words.
    """
    print(len(texts), len(labels))
    assert len(texts) == len(labels)
    # Opening in "w" mode creates the file when it is missing, so no shell
    # `touch` is needed (it broke on paths containing spaces or shell
    # metacharacters and is unavailable on some platforms).
    with open(path, "w", encoding="utf-8") as f:
        f.write("-DOCSTART- O\n\n")
        for words, tags in zip(texts, labels):
            for j, word in enumerate(words):
                f.write("{}\t{}\n".format(word, tags[j]))
            f.write("\n")