# Source code for deepke.relation_extraction.few_shot.generate_k_shot

import ast
import json
import os
import shutil

import numpy as np




def get_labels(path, name, negative_label="no_relation"):
    """Load one record per non-blank line of the file ``path/name``.

    Each non-blank line is parsed as a Python literal, e.g. a dict such as
    ``{'relation': ..., ...}``. Both JSON-style (double-quoted) and
    Python-style (single-quoted) dict lines are accepted.

    Args:
        path: Directory containing the data file.
        name: File name inside ``path``.
        negative_label: Unused; kept only for backward compatibility with
            existing callers.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        ValueError, SyntaxError: if a line is not a valid Python literal
            (this includes lines that try to call functions).
    """
    features = []
    with open(os.path.join(path, name), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                # literal_eval only accepts literals (dicts, lists, strings,
                # numbers, ...), so a tampered data file cannot execute code
                # the way the previous bare eval() allowed.
                features.append(ast.literal_eval(line))
    return features
def generate_k_shot(data_dir, k=8, seeds=(1, 2, 3, 4, 5), root="data"):
    """Build k-shot training splits from ``<root>/train.txt``.

    For every seed, a directory ``<root>/k-shot/<k>-<seed>`` is created.
    It contains a ``train.txt`` holding at most the first ``k`` shuffled
    instances of every relation label, written one JSON object per line.
    ``rel2id.json``, ``val.txt`` and ``test.txt`` are copied from ``root``
    into every generated directory, so each split is self-contained.

    Args:
        data_dir: Not read by this function; kept for interface
            compatibility. NOTE(review): callers appear to pass the target
            split directory (e.g. ``data/k-shot/8-1``) -- confirm before
            giving it a meaning.
        k: Number of instances kept per relation label.
        seeds: Seeds for which a split directory is generated.
        root: Directory holding the full ``train.txt``, ``val.txt``,
            ``test.txt`` and ``rel2id.json``.
    """
    output_dir = os.path.join(root, "k-shot")
    dataset = get_labels(root, "train.txt")

    for seed in seeds:
        # NOTE: the shuffle is in place and cumulative across seeds, so the
        # split for a seed depends on the seeds processed before it. This is
        # kept as-is so previously generated splits stay reproducible.
        np.random.seed(seed)
        np.random.shuffle(dataset)

        setting_dir = os.path.join(output_dir, f"{k}-{seed}")
        os.makedirs(setting_dir, exist_ok=True)

        # Group instances by relation label, preserving the shuffled order.
        by_label = {}
        for instance in dataset:
            by_label.setdefault(instance["relation"], []).append(instance)

        with open(os.path.join(setting_dir, "train.txt"), "w", encoding="utf-8") as f:
            for instances in by_label.values():
                # Keep the first k instances of every label.
                for instance in instances[:k]:
                    f.write(json.dumps(instance))
                    f.write("\n")

        # Previously these were copied only into the hard-coded
        # "data/k-shot/8-1", leaving the other seed directories incomplete.
        for fname in ("rel2id.json", "val.txt", "test.txt"):
            shutil.copyfile(os.path.join(root, fname), os.path.join(setting_dir, fname))