Tools

deepke.relation_extraction.standard.tools.dataset module

deepke.relation_extraction.standard.tools.dataset.collate_fn(cfg)

class deepke.relation_extraction.standard.tools.dataset.CustomDataset(fp)

Bases: torch.utils.data.dataset.Dataset

Stores the data in a List by default.
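`CustomDataset` and `collate_fn(cfg)` follow the usual PyTorch split: a map-style dataset indexes into an in-memory list, and because `collate_fn` takes the config rather than a batch, it presumably returns the function that `DataLoader` calls on each list of samples. The plain-Python sketch below illustrates that pattern. The sample keys (`token2idx`, `seq_len`, `rel2idx`), the output keys, padding with 0, and the `pad_id` argument standing in for `cfg` are all assumptions, not DeepKE's confirmed data format.

```python
from typing import Any, Callable, Dict, List


class ListDataset:
    """List-backed, map-style dataset (the two methods a torch Dataset subclass overrides)."""

    def __init__(self, samples: List[Dict[str, Any]]) -> None:
        self.samples = samples  # all samples held in memory as one Python list

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.samples[index]

    def __len__(self) -> int:
        return len(self.samples)


def make_collate_fn(pad_id: int = 0) -> Callable[[List[Dict[str, Any]]], Dict[str, List]]:
    """Hypothetical factory standing in for collate_fn(cfg)."""

    def collate(batch: List[Dict[str, Any]]) -> Dict[str, List]:
        # Sort longest-first, then right-pad every sequence to the batch maximum.
        batch = sorted(batch, key=lambda s: s["seq_len"], reverse=True)
        max_len = batch[0]["seq_len"]
        return {
            "word": [s["token2idx"] + [pad_id] * (max_len - len(s["token2idx"])) for s in batch],
            "lens": [s["seq_len"] for s in batch],
            "labels": [s["rel2idx"] for s in batch],
        }

    return collate
```

Passing `make_collate_fn()` as the `collate_fn` argument of `torch.utils.data.DataLoader` would apply it to every batch. Sorting longest-first is a common choice when batches later go through `torch.nn.utils.rnn.pack_padded_sequence`, whose default `enforce_sorted=True` expects decreasing lengths.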

deepke.relation_extraction.standard.tools.loss module

class deepke.relation_extraction.standard.tools.loss.LabelSmoothSoftmaxCEV1(lb_smooth=0.1, reduction='mean', ignore_index=-100)

Bases: torch.nn.modules.module.Module

forward(logits, label)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
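`LabelSmoothSoftmaxCEV1` carries no description here. Assuming it implements standard label-smoothed cross-entropy, the loss for one example is -sum_c q_c * log softmax(z)_c with a softened target q instead of a one-hot vector. The plain-Python sketch below assumes the split in which the true class gets `1 - lb_smooth` and every other class `lb_smooth / C`, skips positions labeled `ignore_index`, and averages over the remaining ones for `reduction='mean'`; DeepKE's exact weighting may differ.

```python
import math
from typing import List, Sequence, Union


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def label_smooth_ce(logits: Sequence[Sequence[float]], labels: Sequence[int],
                    lb_smooth: float = 0.1, reduction: str = "mean",
                    ignore_index: int = -100) -> Union[float, List[float]]:
    """Label-smoothed cross-entropy over a batch of logit rows (sketch)."""
    losses: List[float] = []
    n_valid = 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            losses.append(0.0)  # ignored positions contribute no loss
            continue
        n_valid += 1
        num_classes = len(row)
        # Assumed split: 1 - lb_smooth on the true class, lb_smooth / C elsewhere.
        lb_pos, lb_neg = 1.0 - lb_smooth, lb_smooth / num_classes
        target = [lb_pos if c == label else lb_neg for c in range(num_classes)]
        losses.append(-sum(t * lp for t, lp in zip(target, log_softmax(row))))
    if reduction == "mean":
        return sum(losses) / n_valid if n_valid else 0.0  # mean over non-ignored positions
    if reduction == "sum":
        return sum(losses)
    return losses  # reduction == "none"
```

With `lb_smooth=0` the target is one-hot and the function reduces to ordinary cross-entropy, which makes a quick sanity check.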
deepke.relation_extraction.standard.tools.loss.taylor_softmax_v1(x, dim=1, n=4, use_log=False)
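Taylor softmax replaces `exp` in softmax with its truncated Taylor series f(x) = 1 + x + x^2/2! + ... + x^n/n! and normalizes f along `dim`. For even `n` that polynomial is positive for every real x, so the output is still a probability distribution. The sketch below assumes this definition, works on a single 1-D list instead of a tensor with a `dim` argument, and returns log-probabilities when `use_log=True`.

```python
import math
from typing import List, Sequence


def taylor_softmax(x: Sequence[float], n: int = 4, use_log: bool = False) -> List[float]:
    """Taylor softmax of a 1-D sequence (sketch of the technique, not DeepKE's code)."""
    if n <= 0 or n % 2 != 0:
        raise ValueError("n must be a positive even integer so every term stays positive")
    # n-th order Taylor expansion of exp(v): sum_{i=0}^{n} v^i / i!
    f = [sum(v ** i / math.factorial(i) for i in range(n + 1)) for v in x]
    total = sum(f)
    probs = [fv / total for fv in f]
    return [math.log(p) for p in probs] if use_log else probs
```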
class deepke.relation_extraction.standard.tools.loss.LogTaylorSoftmaxV1(dim=1, n=2)

Bases: torch.nn.modules.module.Module

forward(x)

training: bool
class deepke.relation_extraction.standard.tools.loss.TaylorCrossEntropyLossV1(n=2, ignore_index=-1, reduction='mean')

Bases: torch.nn.modules.module.Module

forward(logits, labels)

training: bool
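`LogTaylorSoftmaxV1` and `TaylorCrossEntropyLossV1` presumably chain together: the log of the Taylor softmax (order `n=2` by default) feeds a negative log-likelihood loss that skips targets equal to `ignore_index` (default -1). The sketch below assumes that composition and re-implements it on plain lists; it is not DeepKE's code.

```python
import math
from typing import List, Sequence, Union


def log_taylor_softmax(x: Sequence[float], n: int = 2) -> List[float]:
    """Log of the n-th order Taylor softmax of a 1-D sequence (n must be even)."""
    f = [sum(v ** i / math.factorial(i) for i in range(n + 1)) for v in x]
    total = sum(f)
    return [math.log(fv / total) for fv in f]


def taylor_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                         n: int = 2, ignore_index: int = -1,
                         reduction: str = "mean") -> Union[float, List[float]]:
    """NLL of the true class under the Taylor softmax, skipping ignore_index targets."""
    losses = [0.0 if label == ignore_index else -log_taylor_softmax(row, n)[label]
              for row, label in zip(logits, labels)]
    n_valid = sum(1 for label in labels if label != ignore_index)
    if reduction == "mean":
        return sum(losses) / n_valid if n_valid else 0.0
    if reduction == "sum":
        return sum(losses)
    return losses  # reduction == "none"
```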
class deepke.relation_extraction.standard.tools.loss.FocalLoss(gamma=0, alpha=None, size_average=True)

Bases: torch.nn.modules.module.Module

training: bool
forward(input, target)

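Focal loss scales the cross-entropy of the true class by (1 - p_t)^gamma, so examples the model already classifies confidently contribute less; with the defaults `gamma=0` and `alpha=None` it reduces to plain cross-entropy. The sketch below assumes that `alpha`, when given, is a per-class weight list and that `size_average=True` means averaging over the batch (otherwise summing); both are assumptions about this class.

```python
import math
from typing import Optional, Sequence


def focal_loss(logits: Sequence[Sequence[float]], targets: Sequence[int],
               gamma: float = 0.0, alpha: Optional[Sequence[float]] = None,
               size_average: bool = True) -> float:
    """Multi-class focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t) (sketch)."""
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)
        log_pt = row[t] - (m + math.log(sum(math.exp(v - m) for v in row)))  # log p of true class
        pt = math.exp(log_pt)
        weight = alpha[t] if alpha is not None else 1.0  # assumed per-class weighting
        losses.append(-weight * (1.0 - pt) ** gamma * log_pt)
    return sum(losses) / len(losses) if size_average else sum(losses)
```

Raising `gamma` shrinks the loss on easy examples, which is the usual motivation for applying it to imbalanced label distributions.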

deepke.relation_extraction.standard.tools.metrics module

class deepke.relation_extraction.standard.tools.metrics.Metric

Bases: object

abstract reset()

Resets the metric to its initial state. This is called at the start of each epoch.

abstract update(*args)

Updates the metric’s state using the passed batch output. This is called once for each batch.

abstract compute()

Computes the metric from its accumulated state. This is called at the end of each epoch.

Return :

the actual quantity of interest
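The three abstract methods describe an epoch-level lifecycle: `reset()` at the start of an epoch, `update()` once per batch, `compute()` at the end. The sketch below mirrors that interface with `abc.ABC` (the documented class derives from `object`, so using `abc` here is an assumption) and adds a hypothetical `Accuracy` subclass to show the call order.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class Metric(ABC):
    @abstractmethod
    def reset(self) -> None:
        """Return the metric to its initial state (start of each epoch)."""

    @abstractmethod
    def update(self, *args) -> None:
        """Accumulate state from one batch's output."""

    @abstractmethod
    def compute(self):
        """Return the quantity of interest (end of each epoch)."""


class Accuracy(Metric):
    """Hypothetical subclass: fraction of exact label matches over an epoch."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct, self.total = 0, 0

    def update(self, y_true: Sequence[int], y_pred: Sequence[int]) -> None:
        self.correct += sum(int(t == p) for t, p in zip(y_true, y_pred))
        self.total += len(y_true)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0
```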

class deepke.relation_extraction.standard.tools.metrics.PRMetric

Bases: object

reset()

Resets the accumulated state to zero.

update(y_true: torch.Tensor, y_pred: torch.Tensor)

Adds the batch's tensors to the accumulated state, keeping only their values and detaching them from the computation graph.

compute()

Computes and returns the accuracy, precision, recall and F1 score.
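`PRMetric` accumulates labels across batches and reports accuracy, precision, recall and F1 at the end. The documentation does not say how precision, recall and F1 are averaged over relation classes; the plain-Python sketch below assumes macro averaging (the unweighted mean of per-class scores) and stores integer labels instead of detached tensors. The class name `PRMetricSketch` is hypothetical.

```python
from typing import List, Sequence, Tuple


class PRMetricSketch:
    """Hypothetical stand-in for PRMetric working on plain integer labels."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.y_true: List[int] = []
        self.y_pred: List[int] = []

    def update(self, y_true: Sequence[int], y_pred: Sequence[int]) -> None:
        # The documented update() takes tensors and detaches them; lists need no detaching.
        self.y_true.extend(y_true)
        self.y_pred.extend(y_pred)

    def compute(self) -> Tuple[float, float, float, float]:
        pairs = list(zip(self.y_true, self.y_pred))
        if not pairs:
            return 0.0, 0.0, 0.0, 0.0
        acc = sum(t == p for t, p in pairs) / len(pairs)
        precisions, recalls, f1s = [], [], []
        for c in sorted(set(self.y_true) | set(self.y_pred)):
            tp = sum(t == c and p == c for t, p in pairs)
            n_pred = sum(p == c for _, p in pairs)
            n_true = sum(t == c for t, _ in pairs)
            prec = tp / n_pred if n_pred else 0.0
            rec = tp / n_true if n_true else 0.0
            precisions.append(prec)
            recalls.append(rec)
            f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
        k = len(f1s)  # macro average: unweighted mean over classes
        return acc, sum(precisions) / k, sum(recalls) / k, sum(f1s) / k
```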

deepke.relation_extraction.standard.tools.preprocess module

deepke.relation_extraction.standard.tools.preprocess.preprocess(cfg)

Runs the data preprocessing stage.

deepke.relation_extraction.standard.tools.serializer module

class deepke.relation_extraction.standard.tools.serializer.Serializer(never_split: Optional[List] = None, do_lower_case=True, do_chinese_split=False)

Bases: object

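serialize(text, never_split: Optional[List] = None)

Splits a piece of text into a List of tokens according to the specified splitting rules.

Args :

text (String) : the text to split
never_split (List) : words that must not be split; empty by default

Return :

output_tokens (List) : the tokens after splitting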

static is_control(char)

Checks whether a character is a control character.

Arg :

char : the character

Return :

bool : the result of the check

static is_whitespace(char)

Checks whether a character is a whitespace character.

Arg :

char : the character

Return :

bool : the result of the check

static is_chinese_char(cp)

Checks whether a character is a Chinese character.

Arg :

cp (char) : the character

Return :

bool : the result of the check

static is_punctuation(char)

Checks whether a character is a punctuation character.

Arg :

char : the character

Return :

bool : the result of the check
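The `Serializer` interface (`never_split`, `do_lower_case`, and the `is_control` / `is_whitespace` / `is_chinese_char` / `is_punctuation` helpers) closely resembles the basic tokenizer that ships with BERT. The sketch below re-implements that style of splitting under the assumption that the rules match: control characters are dropped, whitespace is normalized, each CJK character becomes its own token, punctuation is split off, and `never_split` words are kept whole. The CJK ranges are deliberately abbreviated, `is_chinese_char` here takes an integer code point (as the parameter name `cp` suggests), and `do_chinese_split` is not modeled.

```python
import unicodedata
from typing import List, Optional


def is_whitespace(char: str) -> bool:
    # Space, tab, newline, carriage return, or any Unicode space separator (Zs).
    return char in " \t\n\r" or unicodedata.category(char) == "Zs"


def is_control(char: str) -> bool:
    # Tab/newline/CR count as whitespace here, not as control characters.
    if char in "\t\n\r":
        return False
    return unicodedata.category(char).startswith("C")


def is_punctuation(char: str) -> bool:
    cp = ord(char)
    # Treat all non-alphanumeric ASCII symbols (e.g. "$", "^") as punctuation too.
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")


def is_chinese_char(cp: int) -> bool:
    # Abbreviated: CJK Unified Ideographs and Extension A only.
    return 0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF


def serialize(text: str, never_split: Optional[List[str]] = None,
              do_lower_case: bool = True) -> List[str]:
    keep_whole = set(never_split or [])
    # Pass 1: drop control chars, normalize whitespace, pad CJK chars with spaces.
    cleaned = []
    for ch in text:
        if is_control(ch):
            continue
        if is_whitespace(ch):
            cleaned.append(" ")
        elif is_chinese_char(ord(ch)):
            cleaned.append(f" {ch} ")
        else:
            cleaned.append(ch)
    # Pass 2: split on whitespace, then split punctuation off each word.
    tokens: List[str] = []
    for word in "".join(cleaned).split():
        if word in keep_whole:
            tokens.append(word)
            continue
        if do_lower_case:
            word = word.lower()
        current = ""
        for ch in word:
            if is_punctuation(ch):
                if current:
                    tokens.append(current)
                current = ""
                tokens.append(ch)
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens
```

For example, `serialize("[HEAD] met", never_split=["[HEAD]"])` keeps the marker intact instead of splitting off its brackets.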

deepke.relation_extraction.standard.tools.trainer module

deepke.relation_extraction.standard.tools.trainer.train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg)
Trains the model.
Args:

epoch (int): the current epoch number.
model (class): the model to train.
dataloader (dict): dict of dataset iterators. Keys are task names, values are the corresponding dataloaders.
optimizer (Callable): the training optimizer.
criterion (Callable): the training loss criterion.
device (torch.device): the device to train on.
writer (class): writer for TensorBoard output.
cfg: the training configuration.

Return:

losses[-1] : the last recorded training loss

deepke.relation_extraction.standard.tools.trainer.validate(epoch, model, dataloader, criterion, device, cfg)
Validates the model.
Args:

epoch (int): the current epoch number.
model (class): the model to validate.
dataloader (dict): dict of dataset iterators. Keys are task names, values are the corresponding dataloaders.
criterion (Callable): the loss criterion used for validation.
device (torch.device): the device to validate on.
cfg: the validation configuration.

Return:

f1 : the F1 score
loss : the validation loss
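`train` returns the last recorded training loss and `validate` returns an F1 score and a validation loss, which suggests an outer loop that calls both once per epoch and keeps the epoch with the best validation F1. The sketch below shows such a driver with the two functions replaced by hypothetical callables so it runs without a model; it assumes `validate` returns `(f1, loss)` in the documented order, and the best-F1 selection rule is an assumption rather than something this page documents.

```python
from typing import Callable, Tuple


def run_epochs(train_fn: Callable[[int], float],
               validate_fn: Callable[[int], Tuple[float, float]],
               epochs: int) -> Tuple[int, float]:
    """Hypothetical driver: train and validate each epoch, keep the best-F1 epoch."""
    best_epoch, best_f1 = 0, float("-inf")
    for epoch in range(1, epochs + 1):
        train_loss = train_fn(epoch)                # stands in for train(epoch, model, ...)
        valid_f1, valid_loss = validate_fn(epoch)   # stands in for validate(epoch, model, ...)
        print(f"epoch {epoch}: train loss {train_loss:.4f}, "
              f"valid loss {valid_loss:.4f}, f1 {valid_f1:.4f}")
        if valid_f1 > best_f1:
            best_epoch, best_f1 = epoch, valid_f1
    return best_epoch, best_f1
```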

deepke.relation_extraction.standard.tools.vocab module

class deepke.relation_extraction.standard.tools.vocab.Vocab(name: str = 'basic', init_tokens: Sequence = {'bos_token': '[BOS]', 'cls_token': '[CLS]', 'eos_token': '[EOS]', 'head_token': 'HEAD', 'mask_token': '[MASK]', 'pad_token': '[PAD]', 'sep_token': '[SEP]', 'tail_token': 'TAIL', 'unk_token': '[UNK]'})

Bases: object

Builds a vocabulary: adds words and removes low-frequency words.

add_words(words: Sequence)

Adds words from a list.

Arg :

words (List) : the words to add

trim(min_freq=2, verbose: Optional[bool] = True)

Removes words whose frequency is below min_freq from the vocabulary.

Args :

min_freq (int) : the minimum word frequency
verbose (bool) : whether to print a log message
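`Vocab` keeps a word-to-index mapping plus frequency counts, grows through `add_words`, and shrinks through `trim`. The sketch below is a hypothetical minimal version (`VocabSketch`). Three details are assumptions: special tokens are always kept, indices are reassigned contiguously after trimming, and `init_tokens` is a plain list rather than the dict shown in the signature.

```python
from collections import Counter
from typing import Dict, Sequence


class VocabSketch:
    """Hypothetical minimal vocabulary with add_words() and trim()."""

    def __init__(self, name: str = "basic",
                 init_tokens: Sequence[str] = ("[PAD]", "[UNK]")) -> None:
        self.name = name
        self.init_tokens = list(init_tokens)
        self.word2idx: Dict[str, int] = {}
        self.idx2word: Dict[int, str] = {}
        self.count: Counter = Counter()
        self.add_words(self.init_tokens)

    def add_words(self, words: Sequence[str]) -> None:
        for word in words:
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word
            self.count[word] += 1

    def trim(self, min_freq: int = 2, verbose: bool = True) -> None:
        # Keep special tokens unconditionally; drop words seen fewer than min_freq times.
        keep = [w for w in self.word2idx
                if w in self.init_tokens or self.count[w] >= min_freq]
        if verbose:
            print(f"{self.name}: kept {len(keep)} of {len(self.word2idx)} tokens "
                  f"(min_freq={min_freq})")
        # Reassign contiguous indices in the original insertion order.
        self.word2idx = {w: i for i, w in enumerate(keep)}
        self.idx2word = {i: w for i, w in enumerate(keep)}
        self.count = Counter({w: self.count[w] for w in keep})
```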