import torch
from torch import nn
from torch.nn import functional as F
from transformers.configuration_bart import BartConfig
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs
from ..utils import avg_token_embeddings, seq_to_mask, get_model_device
from functools import partial
from typing import Union
class PromptBartEncoder(nn.Module):
def __init__(self, encoder):
super(PromptBartEncoder, self).__init__()
self.bart_encoder = encoder
    def forward(self, src_tokens, attention_mask=None, past_key_values=None):
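        # Delegate to the underlying BART encoder; `past_key_values` carries the prompt key/value prefixes when prompting is enabled.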
encoder_dicts = self.bart_encoder(input_ids=src_tokens, attention_mask=attention_mask, past_key_values=past_key_values, return_dict=True, output_hidden_states=True)
return encoder_dicts.last_hidden_state, encoder_dicts.hidden_states
class PromptBartDecoder(nn.Module):
def __init__(self, decoder, pad_token_id, label_ids, use_prompt=False, prompt_len=10, learn_weights=False):
super(PromptBartDecoder, self).__init__()
self.bart_decoder = decoder
self.pad_token_id = pad_token_id
self.use_prompt = use_prompt
self.prompt_len = prompt_len
self.learn_weights = learn_weights
self.label_ids = label_ids
print(label_ids)
        if self.learn_weights:  # set learnable average weights
self.averge_weights = nn.ParameterList(parameters=None)
for id in label_ids:
if len(id) > 1:
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id)).uniform_(1.0, 2.5)))
print(self.averge_weights)
mapping = [0, 2]
for id in label_ids:
mapping += id[:1]
mapping = torch.LongTensor(mapping)
else:
mapping = torch.LongTensor([0, 2]+label_ids)
self.label_start_id = min(label_ids)
self.label_end_id = max(label_ids)+1
self.register_buffer('mapping', mapping)
self.src_start_index = len(mapping)
hidden_size = decoder.embed_tokens.weight.size(1)
self.bart_mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
nn.Dropout(0.3),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size))
self.dropout_layer = nn.Dropout(0.3)
    def forward(self, tgt_tokens, prompt_state):
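        # Positions after the first end-of-sequence marker (id 1 in the remapped target space) are treated as padding.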
cumsum = tgt_tokens.eq(1).flip(dims=[1]).cumsum(dim=-1)
tgt_pad_mask = cumsum.flip(dims=[-1]).ne(cumsum[:, -1:])
encoder_outputs = prompt_state.encoder_output # last_hidden_state
attention_mask = prompt_state.encoder_mask # attention_mask
first = prompt_state.first
src_tokens = prompt_state.src_tokens
past_key_values = prompt_state.past_key_values
# mapping target tokens
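        # Target ids below src_start_index are tag/special tokens (looked up via self.mapping); larger ids are pointers into the source sentence.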
mapping_token_mask = tgt_tokens.lt(self.src_start_index)
mapped_tokens = tgt_tokens.masked_fill(tgt_tokens.ge(self.src_start_index), 0)
tag_mapped_tokens = self.mapping[mapped_tokens]
src_tokens_index = tgt_tokens - self.src_start_index # bsz x num_src_token
src_tokens_index = src_tokens_index.masked_fill(src_tokens_index.lt(0), 0)
if first is not None:
src_tokens = src_tokens.gather(index=first, dim=1)
word_mapped_tokens = src_tokens.gather(index=src_tokens_index, dim=1)
tokens = torch.where(mapping_token_mask, tag_mapped_tokens, word_mapped_tokens) # bsz x max_len
tokens = tokens.masked_fill(tgt_pad_mask, self.pad_token_id)
decoder_input_ids, _, causal_mask = _prepare_bart_decoder_inputs(
self.pad_token_id,
tokens,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=self.bart_decoder.embed_tokens.weight.dtype
)
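        # With prompting enabled, left-pad the causal mask with zeros so every target position can attend to the prompt key/values.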
if self.use_prompt:
assert past_key_values is not None
_, _, seqlen, _ = past_key_values[0]['self']['prev_value'].shape
tgt_len = decoder_input_ids.size(1)
temp_mask = torch.zeros(tgt_len, seqlen).to(causal_mask.device) #tgtlen, preseqlen
causal_mask = torch.cat([temp_mask, causal_mask],dim=1) #tgtlen, preseqlen+tgtlen
if self.training:
tokens = tokens[:, :-1]
decoder_pad_mask = tokens.eq(self.pad_token_id)
dict = self.bart_decoder(input_ids=tokens,
encoder_hidden_states=encoder_outputs, # last_hidden_state
encoder_padding_mask=attention_mask, # attention_mask
decoder_padding_mask=decoder_pad_mask,
decoder_causal_mask=causal_mask[:tokens.size(1), :self.prompt_len+tokens.size(1)],
output_hidden_states=True,
past_key_values=past_key_values,
return_dict=True)
else:
past_key_values = prompt_state.past_key_values
dict = self.bart_decoder(input_ids=tokens,
encoder_hidden_states=encoder_outputs,
encoder_padding_mask=attention_mask,
decoder_padding_mask=None,
decoder_causal_mask=None,
past_key_values=past_key_values,
use_cache=True,
return_dict=True)
hidden_state = dict.last_hidden_state # bsz x max_len x hidden_size
hidden_state = self.dropout_layer(hidden_state)
if not self.training:
prompt_state.past_key_values = dict.past_key_values
logits = hidden_state.new_full((hidden_state.size(0), hidden_state.size(1), self.src_start_index+src_tokens.size(-1)),
fill_value=-1e24)
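        # Logits layout: index 1 = eos, indices 2..src_start_index-1 = entity/tag labels, remaining indices point to source token positions; index 0 stays at -1e24.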
# compute eos scores
eos_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[2:3])) # bsz x max_len x 1
if self.learn_weights: # use averge_weights compute entity labels scores
tag_scores = None
idx = 0
for ids in self.label_ids: # bsz x max_len x num_class
if len(ids) <= 1:
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[ids]))
else:
                    weight = F.softmax(self.averge_weights[idx], dim=0)
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[0]]])) * weight[0]
for i in range(1, len(ids)):
temp_score = temp_score + F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[i]]])) * weight[i]
idx += 1
if tag_scores is None:
tag_scores = temp_score
else:
tag_scores = torch.cat((tag_scores, temp_score), dim=2)
else:
tag_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[self.label_start_id:self.label_end_id])) # bsz x max_len x num_class
# bsz x max_bpe_len x hidden_size
src_outputs = encoder_outputs
if hasattr(self, 'encoder_mlp'):
src_outputs = self.encoder_mlp(src_outputs)
if first is not None:
mask = first.eq(0) # bsz x 1 x max_word_len
# bsz x max_word_len x hidden_size
src_outputs = src_outputs.gather(index=first.unsqueeze(2).repeat(1, 1, src_outputs.size(-1)), dim=1)
else:
mask = attention_mask.eq(0)
# src_outputs = self.decoder.embed_tokens(src_tokens)
mask = mask.unsqueeze(1)
input_embed = self.dropout_layer(self.bart_decoder.embed_tokens(src_tokens)) # bsz x max_word_len x hidden_size
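        # Fuse encoder hidden states with the raw source token embeddings (simple average) before scoring source positions.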
src_outputs = (src_outputs + input_embed)/2
word_scores = torch.einsum('blh,bnh->bln', hidden_state, src_outputs) # bsz x max_len x max_word_len
mask = mask.__or__(src_tokens.eq(2).cumsum(dim=1).ge(1).unsqueeze(1))
word_scores = word_scores.masked_fill(mask, -1e32)
logits[:, :, 1:2] = eos_scores
logits[:, :, 2:self.src_start_index] = tag_scores
logits[:, :, self.src_start_index:] = word_scores
return logits, prompt_state
    def decode(self, tokens, state):
return self(tokens, state)[0][:, -1]
class PromptBartModel(nn.Module):
def __init__(self, tokenizer, label_ids, args):
super(PromptBartModel, self).__init__()
self.use_prompt = args.use_prompt
self.prompt_len = args.prompt_len
self.prompt_dim = args.prompt_dim
self.learn_weights = args.learn_weights
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
bart_name = args.bart_name
self.bart_config = BartConfig.from_pretrained(bart_name)
self.bart_config.use_prompt = args.use_prompt
self.bart_config.preseqlen = args.prompt_len
bart_config = self.bart_config
bart_model = BartModel.from_pretrained(bart_name, config=bart_config)
num_tokens, _ = bart_model.encoder.embed_tokens.weight.shape
bart_model.resize_token_embeddings(len(tokenizer.unique_no_split_tokens)+num_tokens)
bart_model = avg_token_embeddings(tokenizer, bart_model, bart_name, num_tokens)
self.prompt_encoder = PromptBartEncoder(bart_model.encoder)
self.prompt_decoder = PromptBartDecoder(bart_model.decoder, tokenizer.pad_token_id, label_ids, self.use_prompt, self.prompt_len, self.learn_weights)
self.prompt_inputs = torch.arange(self.prompt_len).long()
self.encoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.encoder_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.decoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.decoder_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.prompt_cross_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.cross_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.dropout = nn.Dropout(0.0)
    def forward(self, src_tokens, tgt_tokens, src_seq_len, first):
prompt_state = self.generator(src_tokens, src_seq_len, first)
decoder_outputs, prompt_state = self.prompt_decoder(tgt_tokens, prompt_state)
return decoder_outputs
    def generator(self, src_tokens, src_seq_len, first):
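        # Encode the source once (with prompt prefixes if enabled) and pack everything the decoder and the generation loop need into a PromptBartState.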
batch_size = src_tokens.size(0)
past_key_values = self.get_prompt(batch_size) if self.use_prompt else None
attention_mask = seq_to_mask(src_seq_len, max_len=src_tokens.size(1))
encoder_outputs, hidden_states = self.prompt_encoder(src_tokens, attention_mask=attention_mask, past_key_values=past_key_values)
prompt_state = PromptBartState(encoder_outputs, attention_mask, past_key_values, src_tokens, first, hidden_states[0], self.bart_config.preseqlen)
return prompt_state
    def get_prompt(self, batch_size):
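        # Reparameterize the learned prompt embeddings through the three MLPs into per-layer key/value prefixes; every decoder layer receives a dict with 'self', 'encoder_decoder', and 'encoder' entries consumed by the modified BART attention modules.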
input_tokens = self.prompt_inputs.unsqueeze(0).expand(batch_size, -1).to(self.device)
# encoder prompt
encoder_embed = self.encoder_prompt_embed(input_tokens)
past_key_values = self.encoder_mlp(encoder_embed) #bsz, seqlen, layer*emb
bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # key + value
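        # shape after permute: (decoder_layers * 2, bsz, num_heads, prompt_len, head_dim); split(2) yields one (key, value) pair per layer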
# decoder prompt
decoder_embed = self.decoder_prompt_embed(input_tokens)
past_key_values2 = self.decoder_mlp(decoder_embed) # bsz, seqlen, layer*emb
past_key_values2 = past_key_values2.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values2 = self.dropout(past_key_values2)
past_key_values2 = past_key_values2.permute([2, 0, 3, 1, 4]).split(2)
# cross prompt
cross_embed = self.prompt_cross_embed(input_tokens)
past_key_values_enc = self.cross_mlp(cross_embed) # bsz, seqlen, layer*emb
past_key_values_enc = past_key_values_enc.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values_enc = self.dropout(past_key_values_enc)
past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)
result = []
for i, key_val in enumerate(past_key_values):
temp_dict = {'self': {"prev_key": key_val[0].contiguous(),
"prev_value": key_val[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val.device).bool() #bsz, preseqlen
},
}
key_val2 = past_key_values2[i]
temp_dict['encoder_decoder'] = {"prev_key": key_val2[0].contiguous(),
"prev_value": key_val2[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val2.device).bool()
}
key_val_enc = past_key_values_enc[i]
temp_dict['encoder'] = {"prev_key": key_val_enc[0].contiguous(),
"prev_value": key_val_enc[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val_enc.device).bool()
}
result.append(temp_dict)
return result
class PromptBartState(object):
def __init__(self, encoder_output, encoder_mask, past_key_values, src_tokens, first, src_embed_outputs, preseqlen):
self.encoder_output = encoder_output
self.encoder_mask = encoder_mask
self.past_key_values = past_key_values
self.src_tokens = src_tokens
self.first = first
self.src_embed_outputs = src_embed_outputs
self.preseqlen = preseqlen
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0):
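        # Recursively index_select tensors (and nested lists/tuples of tensors) along `dim` to follow beam reordering.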
if isinstance(state, torch.Tensor):
state = state.index_select(index=indices, dim=dim)
elif isinstance(state, list):
for i in range(len(state)):
assert state[i] is not None
state[i] = self._reorder_state(state[i], indices, dim)
elif isinstance(state, tuple):
tmp_list = []
for i in range(len(state)):
assert state[i] is not None
tmp_list.append(self._reorder_state(state[i], indices, dim))
state = tuple(tmp_list)
else:
raise TypeError(f"Cannot reorder data of type:{type(state)}")
return state
    def reorder_state(self, indices: torch.LongTensor):
        # the base class is plain `object`, so the encoder-side tensors kept on this state are reordered here directly
        self.encoder_output = self._reorder_state(self.encoder_output, indices)
        self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
self.src_tokens = self._reorder_state(self.src_tokens, indices)
if self.first is not None:
self.first = self._reorder_state(self.first, indices)
self.src_embed_outputs = self._reorder_state(self.src_embed_outputs, indices)
if self.past_key_values is not None:
new = []
for layer in self.past_key_values:
new_layer = {}
for key1 in list(layer.keys()):
new_layer_ = {}
for key2 in list(layer[key1].keys()):
if layer[key1][key2] is not None:
layer[key1][key2] = self._reorder_state(layer[key1][key2], indices)
new_layer_[key2] = layer[key1][key2]
new_layer[key1] = new_layer_
new.append(new_layer)
self.past_key_values = new
    def num_samples(self):
if self.encoder_output is not None:
return self.encoder_output.size(0)
else:
return None
class PromptGeneratorModel(nn.Module):
def __init__(self, prompt_model, max_length=20, max_len_a=0.0, num_beams=1,
do_sample=False, bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0, restricter=None):
super(PromptGeneratorModel, self).__init__()
self.prompt_model = prompt_model
self.decoder = prompt_model.prompt_decoder
self.generate_func = partial(greedy_generate, decoder=self.decoder, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty, pad_token_id=pad_token_id,
restricter=restricter)
self.do_sample = do_sample
self.max_length = max_length
self.num_beams = num_beams
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.repetition_penalty = repetition_penalty
self.length_penalty = length_penalty
self.pad_token_id = pad_token_id
self.restricter = restricter
self.max_len_a = max_len_a
    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None, first=None):
"""
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor tgt_tokens: bsz x max_len'
:param torch.LongTensor src_seq_len: bsz
:param torch.LongTensor tgt_seq_len: bsz
:return:
"""
return self.prompt_model(src_tokens, tgt_tokens, src_seq_len, first)
    def predict(self, src_tokens, src_seq_len=None, first=None):
"""
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor src_seq_len: bsz
:return:
"""
prompt_state = self.prompt_model.generator(src_tokens, src_seq_len, first) # encoder output
result = self.generate_func(tokens=None, state=prompt_state)
return result
@torch.no_grad()
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
bos_token_id=None, eos_token_id=None, pad_token_id=0,
repetition_penalty=1, length_penalty=1.0, restricter=None):
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id, restricter=restricter)
else:
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id, restricter=restricter)
return token_ids
def _no_beam_search_generate(decoder: PromptBartDecoder, state, tokens=None, max_length=20, max_len_a=0.0, bos_token_id=None,
eos_token_id=None,
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0,
restricter=None):
device = get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
batch_size = state.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
    if state.num_samples() is not None:
        assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."
if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state) # update state
if restricter is not None:
_, next_tokens = restricter(state, tokens, scores, num_beams=1)
else:
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1).__or__(next_tokens.squeeze(1).eq(eos_token_id))
# tokens = tokens[:, -1:]
if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
while cur_len < real_max_length:
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size
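        # Repetition penalty: make already-generated tokens less likely (negative scores are scaled up in magnitude, positive scores scaled down).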
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)
if eos_token_id is not None and length_penalty != 1.0:
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
eos_mask = scores.new_ones(scores.size(1))
eos_mask[eos_token_id] = 0
eos_mask = eos_mask.unsqueeze(0).eq(1)
scores = scores.masked_scatter(eos_mask, token_scores)
if restricter is not None:
_, next_tokens = restricter(state, token_ids, scores, 1)
else:
next_tokens = scores.argmax(dim=-1, keepdim=True)
next_tokens = next_tokens.squeeze(-1)
if _eos_token_id!=-1:
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
tokens = next_tokens.unsqueeze(1)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
end_mask = next_tokens.eq(_eos_token_id)
dones = dones.__or__(end_mask)
cur_len += 1
if dones.min() == 1:
break
return token_ids
def _beam_search_generate(decoder: PromptBartDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4,
bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=None, pad_token_id=0,
restricter=None) -> torch.LongTensor:
assert do_sample is False
# beam search
device = get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
        batch_size = state.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
    if state.num_samples() is not None:
        assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."
if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state)
vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should not exceed the vocabulary size."
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
if restricter is not None:
_next_scores, _next_tokens = restricter(state, tokens, scores, num_beams+1)
else:
# bsz x (num_beams+1)
_next_scores, _next_tokens = torch.topk(scores, num_beams+1, dim=1, largest=True, sorted=True)
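    # Expand each sample num_beams times so every beam keeps its own copy of the decoder state and source tokens.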
indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
state.reorder_state(indices)
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((batch_size*num_beams, ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((batch_size*num_beams,), fill_value=max_length, dtype=torch.long)
hypos = [
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
not_eos_mask = _next_tokens.ne(_eos_token_id)
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
keep_mask = not_eos_mask.__and__(keep_mask)
next_tokens = _next_tokens.masked_select(keep_mask).view(batch_size, num_beams)
next_scores = _next_scores.masked_select(keep_mask).view(batch_size, num_beams)
rows, cols = not_eos_mask.eq(0)[:, :num_beams].nonzero(as_tuple=True)
if len(rows)>0:
for row, col in zip(rows.tolist(), cols.tolist()):
_token = torch.cat([tokens[row*num_beams], _next_tokens[row, col:col+1]], dim=0)
hypos[row].add(_token.clone(), _next_scores[row, col].item())
# (batch_size, cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size
beam_scores = next_scores.view(-1) # batch_size * num_beams
cur_len = token_ids.size(1)
# 0, num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
while cur_len < real_max_length:
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size)
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)
if _eos_token_id!=-1:
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores)
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
if restricter is not None:
next_scores, ids = restricter(state, token_ids, _scores, 2 * num_beams)
else:
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams)
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
not_eos_mask = next_tokens.ne(_eos_token_id)
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
keep_mask = not_eos_mask.__and__(keep_mask)
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)
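        # Record hypotheses that just emitted eos (or every beam once the length budget is reached) into the per-sample n-best lists.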
flag = True
if cur_len+1 == real_max_length:
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
else:
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
if effective_eos_mask.sum().gt(0):
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
else:
flag = False
if flag:
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
eos_beam_idx.tolist()):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
if _eos_token_id!=-1:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
else:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten
state.reorder_state(reorder_inds)
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
max_lengths[batch_idx*num_beams]==cur_len+1
cur_len += 1
if all(dones):
break
# select the best hypotheses
tgt_len = token_ids.new_zeros(batch_size)
best = []
for i, hypotheses in enumerate(hypos):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
if _eos_token_id!=-1:
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
tgt_len[i] = len(best_hyp)
best.append(best_hyp)
# generate target batch
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i]] = hypo
return decoded
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
    def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
    def is_done(self, best_sum_logprobs):
"""
        If there are enough hypotheses and none of the hypotheses still being generated
        can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
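# A minimal usage sketch (illustrative only): `args`, `tokenizer`, `label_ids`, `src_tokens`,
# `tgt_tokens`, and `src_seq_len` are assumed to come from the surrounding DeepKE data/training
# pipeline and are not defined in this module. The ids 0/1 below refer to the remapped target
# space used by PromptBartDecoder (mapping index 0 -> bos, logits index 1 -> eos); adjust them
# if your configuration differs.
#
#   prompt_bart = PromptBartModel(tokenizer, label_ids, args)
#   generator = PromptGeneratorModel(prompt_bart, max_length=30,
#                                    bos_token_id=0, eos_token_id=1,
#                                    pad_token_id=tokenizer.pad_token_id)
#   logits = generator(src_tokens, tgt_tokens, src_seq_len, first=None)   # training forward
#   pred_ids = generator.predict(src_tokens, src_seq_len, first=None)     # greedy/beam decoding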