Source code for deepke.relation_extraction.document.prepro

from tqdm import tqdm
import torch.nn.functional as F
import ujson as json
import os
import pickle
import numpy as np
import torch

def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):
    """Encode inputs longer than 512 tokens by running the model on two overlapping windows.

    When the padded width of ``input_ids`` is at most 512 the model is called
    once. Otherwise every row whose real length (the row sum of
    ``attention_mask``) exceeds 512 is split into a window over its first
    512 positions and a window over its last 512 real positions; the two
    windows are encoded separately and merged back to the original width.

    Args:
        model: Callable invoked as ``model(input_ids=..., attention_mask=...,
            output_attentions=True)``. Its result is indexed as ``output[0]``
            (sequence output) and ``output[-1][-1]`` (attention taken from the
            last entry of the last output element).
        input_ids: 2-D tensor of token ids.
        attention_mask: Tensor with the same leading layout as ``input_ids``;
            its row sums are used as the real sequence lengths.
        start_tokens: Sequence of token ids prepended to the second window.
        end_tokens: Sequence of token ids appended to the first window.

    Returns:
        Tuple ``(sequence_output, attention)``, both restored to the original
        padded width of ``input_ids``.
        NOTE(review): the attention is presumably (batch, heads, seq, seq) --
        the slicing below only fixes that each per-row attention is 3-D.
    """
    # Split the input to 2 overlapping chunks. Now BERT can encode inputs of which the length are up to 1024.
    _, c = input_ids.size()
    start_tokens = torch.tensor(start_tokens).to(input_ids)
    end_tokens = torch.tensor(end_tokens).to(input_ids)
    len_start = start_tokens.size(0)
    len_end = end_tokens.size(0)

    if c <= 512:
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        return output[0], output[-1][-1]

    # Build the windowed batch: one row per short input, two rows per long one.
    new_input_ids, new_attention_mask, num_seg = [], [], []
    seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()
    for i, l_i in enumerate(seq_len):
        if l_i <= 512:
            new_input_ids.append(input_ids[i, :512])
            new_attention_mask.append(attention_mask[i, :512])
            num_seg.append(1)
        else:
            # First window: leading tokens closed with the end tokens.
            input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)
            # Second window: start tokens followed by the trailing tokens.
            input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)
            attention_mask1 = attention_mask[i, :512]
            attention_mask2 = attention_mask[i, (l_i - 512): l_i]
            new_input_ids.extend([input_ids1, input_ids2])
            new_attention_mask.extend([attention_mask1, attention_mask2])
            num_seg.append(2)

    input_ids = torch.stack(new_input_ids, dim=0)
    attention_mask = torch.stack(new_attention_mask, dim=0)
    output = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=True,
    )
    sequence_output = output[0]
    attention = output[-1][-1]

    # Merge the windows back into one row per original input.
    i = 0
    new_output, new_attention = [], []
    for n_s, l_i in zip(num_seg, seq_len):
        if n_s == 1:
            # Short row: only pad back to the original width.
            merged = F.pad(sequence_output[i], (0, 0, 0, c - 512))
            att = F.pad(attention[i], (0, c - 512, 0, c - 512))
            new_output.append(merged)
            new_attention.append(att)
        elif n_s == 2:
            # Drop the injected end tokens from window 1, then pad on the right.
            output1 = sequence_output[i][:512 - len_end]
            mask1 = attention_mask[i][:512 - len_end]
            att1 = attention[i][:, :512 - len_end, :512 - len_end]
            output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))
            mask1 = F.pad(mask1, (0, c - 512 + len_end))
            att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))

            # Drop the injected start tokens from window 2, then pad on the
            # left so it lines up with its original positions.
            output2 = sequence_output[i + 1][len_start:]
            mask2 = attention_mask[i + 1][len_start:]
            att2 = attention[i + 1][:, len_start:, len_start:]
            output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))
            mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))
            att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])

            # Average the hidden states where the windows overlap; the epsilon
            # avoids a division by zero on positions covered by neither window.
            mask = mask1 + mask2 + 1e-10
            merged = (output1 + output2) / mask.unsqueeze(-1)
            # Sum the attentions and renormalise each row.
            att = (att1 + att2)
            att = att / (att.sum(-1, keepdim=True) + 1e-10)
            new_output.append(merged)
            new_attention.append(att)
        i += n_s

    sequence_output = torch.stack(new_output, dim=0)
    attention = torch.stack(new_attention, dim=0)
    return sequence_output, attention
def chunks(l, n):
    """Split ``l`` into consecutive slices of exactly ``n`` items each.

    Args:
        l: Sliceable sequence (anything supporting ``len`` and slicing).
        n: Size of every slice; also used as the ``range`` step.

    Returns:
        List of the slices ``l[0:n]``, ``l[n:2n]``, ... in order. An empty
        ``l`` yields an empty list.

    Raises:
        AssertionError: If a slice is shorter than ``n``, i.e. ``len(l)`` is
            not a multiple of ``n``.
    """
    res = []
    for i in range(0, len(l), n):
        piece = l[i:i + n]
        if len(piece) != n:
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept so
            # existing callers see the same error.
            raise AssertionError(
                "sequence of length {} cannot be split into chunks of {}".format(len(l), n)
            )
        res.append(piece)
    return res
class ReadDataset:
    """Dispatch dataset files to the reader function that matches the dataset name."""

    def __init__(self, args, dataset: str, tokenizer, max_seq_Length: int = 1024,
                 transformers: str = 'bert') -> None:
        """Store the reader configuration.

        Args:
            args: Configuration object passed through to the reader function.
            dataset: Dataset name; ``read`` only handles ``'docred'``.
            tokenizer: Tokenizer passed through to the reader function.
            max_seq_Length: Maximum sequence length passed to the reader function.
            transformers: Encoder family name; also part of the cache file name.
        """
        self.args = args
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_Length = max_seq_Length
        self.transformers = transformers

    def read(self, file_in: str):
        """Read ``file_in`` with the reader registered for ``self.dataset``.

        The cache path is the part of ``file_in`` before the first ``'.json'``
        followed by ``_<transformers>_<dataset>.pkl``.

        Args:
            file_in: Path of the JSON file to read.

        Returns:
            Whatever ``read_docred`` returns for this file.

        Raises:
            RuntimeError: If ``self.dataset`` is not ``'docred'``.
        """
        stem = file_in.split('.json')[0]
        save_file = ''.join([stem, '_', self.transformers, '_', self.dataset, '.pkl'])

        if self.dataset != 'docred':
            raise RuntimeError("No read func for this dataset.")

        return read_docred(
            self.args,
            self.transformers,
            file_in,
            save_file,
            self.tokenizer,
            self.max_seq_Length,
        )
# Entity-type table: the index of a mention's type selects its BERT "[unusedN]" marker.
_DOCRED_ENTITY_TYPES = ["-", "ORG", "-", "LOC", "-", "TIME", "-", "PER", "-", "MISC", "-", "NUM"]


def _docred_marker(transfermers, mention_type, offset=0):
    """Return the one-token marker list placed next to a mention boundary."""
    if transfermers == 'bert':
        return ['[unused' + str(_DOCRED_ENTITY_TYPES.index(mention_type) + offset) + ']']
    return ['*']


def _docred_tokenize(sample, tokenizer, transfermers):
    """Tokenize one document and return ``(tokens, per-sentence word->token offset maps)``."""
    # Type of the first mention starting / ending at each (sentence, word)
    # position; ``setdefault`` keeps the first mention, like ``list.index`` did.
    start_types, end_types = {}, {}
    for entity in sample['vertexSet']:
        for mention in entity:
            sent_id = mention["sent_id"]
            pos = mention["pos"]
            start_types.setdefault((sent_id, pos[0]), mention['type'])
            end_types.setdefault((sent_id, pos[1] - 1), mention['type'])

    sents, sent_map = [], []
    for i_s, sent in enumerate(sample['sents']):
        new_map = {}
        for i_t, token in enumerate(sent):
            tokens_wordpiece = tokenizer.tokenize(token)
            key = (i_s, i_t)
            if key in start_types:
                tokens_wordpiece = _docred_marker(transfermers, start_types[key]) + tokens_wordpiece
            if key in end_types:
                # End markers use the same type index shifted by 50.
                tokens_wordpiece = tokens_wordpiece + _docred_marker(transfermers, end_types[key], 50)
            new_map[i_t] = len(sents)
            sents.extend(tokens_wordpiece)
        # One-past-the-end entry so exclusive mention ends can be mapped too.
        # ``len(sent)`` equals the old ``i_t + 1`` but is also defined for an
        # empty sentence.
        new_map[len(sent)] = len(sents)
        sent_map.append(new_map)
    return sents, sent_map


def _docred_pair_labels(sample, docred_rel2id, num_entities):
    """Build multi-hot labels for every ordered entity pair; return ``(relations, hts, n_pos, n_neg)``."""
    num_relations = len(docred_rel2id)

    # Relation ids annotated for each (head, tail) pair, in annotation order.
    pair_relations = {}
    if "labels" in sample:
        for label in sample['labels']:
            r = int(docred_rel2id[label['r']])
            pair_relations.setdefault((label['h'], label['t']), []).append(r)

    relations, hts = [], []
    # Get positive samples from dataset
    for (h, t), rel_ids in pair_relations.items():
        relation = [0] * num_relations
        for r in rel_ids:
            relation[r] = 1
        relations.append(relation)
        hts.append([h, t])
    n_pos = len(hts)

    # Get negative samples from dataset: every other ordered pair is labelled
    # with index 0 only.
    for h in range(num_entities):
        for t in range(num_entities):
            if h != t and (h, t) not in pair_relations:
                relations.append([1] + [0] * (num_relations - 1))
                hts.append([h, t])
    return relations, hts, n_pos, len(hts) - n_pos


def read_docred(args, transfermers, file_in, save_file, tokenizer, max_seq_length=1024):
    """Convert a DocRED-style JSON file into model features, cached as a pickle.

    If ``save_file`` already exists it is unpickled and returned. Otherwise
    ``file_in`` is parsed, every document is tokenized with entity markers,
    every ordered entity pair gets a multi-hot relation label, and the
    resulting feature list is pickled to ``save_file``.

    Args:
        args: Object whose ``data_dir`` attribute is the directory holding
            ``rel2id.json``.
        transfermers: Encoder family name. ``'bert'`` wraps mentions in typed
            ``[unusedN]`` markers; any other value uses ``'*'``.
        file_in: Path of the JSON file to read; an empty string returns ``None``.
        save_file: Path of the pickle cache to load from or write to.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids`` and
            ``build_inputs_with_special_tokens`` (the methods called here).
        max_seq_length: Token sequences are cut to ``max_seq_length - 2``
            before the tokenizer adds its special tokens.

    Returns:
        List of dicts with keys ``'input_ids'``, ``'entity_pos'``, ``'labels'``,
        ``'hts'`` and ``'title'``, or ``None`` when ``file_in`` is empty.
    """
    if os.path.exists(save_file):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # save_file at cache files this function wrote itself.
        with open(file=save_file, mode='rb') as fr:
            features = pickle.load(fr)
        print('load preprocessed data from {}.'.format(save_file))
        return features

    # Close the handle deterministically instead of leaking it.
    with open(f'{args.data_dir}/rel2id.json', 'r') as fr:
        docred_rel2id = json.load(fr)
    if file_in == "":
        return None
    with open(file_in, "r") as fh:
        data = json.load(fh)

    max_len = 0
    up512_num = 0
    i_line = 0
    pos_samples = 0
    neg_samples = 0
    features = []

    for sample in tqdm(data, desc="Example"):
        entities = sample['vertexSet']
        sents, sent_map = _docred_tokenize(sample, tokenizer, transfermers)

        max_len = max(max_len, len(sents))
        if len(sents) > 512:
            up512_num += 1

        relations, hts, n_pos, n_neg = _docred_pair_labels(sample, docred_rel2id, len(entities))
        pos_samples += n_pos
        neg_samples += n_neg

        # Token-level (start, end) span of every mention, grouped by entity.
        entity_pos = [
            [
                (sent_map[m["sent_id"]][m["pos"][0]], sent_map[m["sent_id"]][m["pos"][1]])
                for m in e
            ]
            for e in entities
        ]

        # Sanity check: exactly one label row per ordered pair of distinct entities.
        assert len(relations) == len(entities) * (len(entities) - 1)

        if not hts:
            # Diagnostic kept from the original code: word count of the last
            # sentence of a document that produced no entity pair (0 when the
            # document has no sentences).
            print(len(sample['sents'][-1]) if sample['sents'] else 0)

        sents = sents[:max_seq_length - 2]
        input_ids = tokenizer.convert_tokens_to_ids(sents)
        input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)

        i_line += 1
        features.append({
            'input_ids': input_ids,
            'entity_pos': entity_pos,
            'labels': relations,
            'hts': hts,
            'title': sample['title'],
        })

    print("# of documents {}.".format(i_line))
    print("# of positive examples {}.".format(pos_samples))
    print("# of negative examples {}.".format(neg_samples))
    print("# {} examples len>512 and max len is {}.".format(up512_num, max_len))

    with open(file=save_file, mode='wb') as fw:
        pickle.dump(features, fw)
    print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))

    return features