Utils¶
deepke.name_entity_re.few_shot.utils.util module¶
- deepke.name_entity_re.few_shot.utils.util.avg_token_embeddings(tokenizer: transformers.tokenization_bart.BartTokenizer, bart_model: transformers.modeling_bart.BartModel, bart_name, num_tokens)[source]¶
When initializing newly added tokens, use the average of their token embeddings
- Parameters
- Raises
RuntimeError – [description]
- Returns
[description]
- Return type
[type]
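The averaging idea above can be sketched with a plain embedding matrix, independent of the actual BART model. The helper name `init_added_tokens_with_avg` and the id-mapping argument are hypothetical, not part of the DeepKE API:

```python
import torch

def init_added_tokens_with_avg(embed: torch.nn.Embedding, added_to_subtokens: dict):
    # added_to_subtokens maps each newly added token id to the ids of the
    # existing sub-tokens whose embeddings should be averaged into it.
    with torch.no_grad():
        for new_id, sub_ids in added_to_subtokens.items():
            embed.weight[new_id] = embed.weight[torch.tensor(sub_ids)].mean(dim=0)
    return embed
```

The same pattern applies after `resize_token_embeddings`: the freshly appended rows are random, and replacing them with sub-token averages gives the added label tokens a meaningful starting point.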
- deepke.name_entity_re.few_shot.utils.util.seq_to_mask(seq_len, max_len)[source]¶
Get an attention mask from the sequence lengths in a batch.
- Parameters
seq_len ([torch.tensor]) – [shape: bsz, each sequence length in a batch]
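A minimal sketch of this masking logic, assuming `seq_len` is a 1-D tensor of per-sequence lengths; the implementation here is illustrative, not the DeepKE source:

```python
import torch

def seq_to_mask(seq_len, max_len):
    # seq_len: (bsz,) tensor of lengths; returns a (bsz, max_len) bool mask
    # where position j is True iff j < seq_len[i] for sequence i.
    positions = torch.arange(max_len).unsqueeze(0)   # (1, max_len)
    return positions < seq_len.unsqueeze(1)          # broadcast to (bsz, max_len)
```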
- deepke.name_entity_re.few_shot.utils.util.get_loss(tgt_tokens, tgt_seq_len, pred)[source]¶
- Parameters
tgt_tokens – bsz x max_len, containing [sos, token, eos]
pred – bsz x max_len-1 x vocab_size
- Returns
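Since `pred` has one fewer step than `tgt_tokens`, the targets must be shifted to drop the leading sos before computing the loss. A hedged sketch of that shift-and-mask pattern, with padding positions excluded via `ignore_index` (an illustration, not the DeepKE implementation):

```python
import torch
import torch.nn.functional as F

def get_loss(tgt_tokens, tgt_seq_len, pred):
    # tgt_tokens: (bsz, max_len) containing [sos, token..., eos]
    # pred:       (bsz, max_len - 1, vocab_size), one prediction per post-sos step
    tgt = tgt_tokens[:, 1:]                              # drop sos to align with pred
    positions = torch.arange(tgt.size(1)).unsqueeze(0)   # (1, max_len - 1)
    mask = positions < (tgt_seq_len - 1).unsqueeze(1)    # valid target positions
    tgt = tgt.masked_fill(~mask, -100)                   # -100 is ignored below
    return F.cross_entropy(pred.transpose(1, 2), tgt, ignore_index=-100)
```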
- deepke.name_entity_re.few_shot.utils.util.convert_preds_to_outputs(preds, raw_words, mapping, tokenizer)[source]¶
Convert model predictions to BIO outputs.
- Parameters
preds ([torch.Tensor]) – [prompt model predictions, (bsz x seq_len x labels)]
raw_words ([List]) – [source raw words]
mapping ([dict]) – [map entity labels to <<>>]
tokenizer – [BartTokenizer]
- Returns
[each item has the same length as raw_words, in BIO format]
- Return type
[outputs (List)]
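The final step of such a conversion, turning entity spans over word indices into per-word BIO tags, can be sketched as follows. The helper `spans_to_bio` and its span format `(start, end, label)` are hypothetical, chosen only to illustrate the tagging scheme:

```python
def spans_to_bio(raw_words, spans):
    # spans: list of (start, end, label) over word indices, end exclusive.
    # Returns one tag per word: B- at span start, I- inside, O elsewhere.
    tags = ["O"] * len(raw_words)
    for start, end, label in spans:
        tags[start] = "B-" + label
        for i in range(start + 1, end):
            tags[i] = "I-" + label
    return tags
```

Each output list lines up one-to-one with `raw_words`, which is the invariant the Returns section above describes.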