import warnings
# from timeit import default_timer  # only needed if the timing probes below are re-enabled
from collections import UserDict, defaultdict
from typing import Optional, Tuple, Any

import torch
from transformers import BeamScorer, BeamSearchScorer
from transformers.generation import BeamHypotheses

from ...utils.torch_utils import first_several_nonzero_indices


class MyBeamSearchScorer(BeamScorer):
    """Variant of `transformers.BeamSearchScorer` whose `process` step avoids the
    per-candidate Python loop over the whole batch: only finished (EOS) candidates
    are iterated on the CPU, and the surviving beams for all batch elements are
    selected at once via `first_several_nonzero_indices`."""

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done: torch.Tensor = \
            torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
                ", or `group_beam_search(...)`."
            )

        self.t_dict = defaultdict(lambda: 0.0)

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor]:
        # t0 = default_timer()

        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        if not (batch_size == (input_ids.shape[0] // self.group_size)):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        batch_beam_indices = torch.arange(batch_size, device=device)[:, None] * self.group_size + next_indices

        # self.t_dict['prepare'] += default_timer() - t0
        # t0 = default_timer()

        # for eos
        # `-42` is a sentinel that no real (non-negative) token id can equal, used when `eos_token_id` is None
        eos = eos_token_id if eos_token_id is not None else -42
        is_eos_and_non_done = (~self._done[:, None]) & (next_tokens == eos)
        # self.t_dict['is_eos_and_non_done_sum'] += is_eos_and_non_done.sum().cpu().item()
        # self.t_dict['is_eos_and_non_done_nonzero'] += int(is_eos_and_non_done.sum().cpu().item() > 0)
        # self.t_dict['is_eos_and_non_done_count'] += 1
        # self.t_dict['for-eos-create-a'] += default_timer() - t0
        # t0 = default_timer()
        next_indices_selected = next_indices[is_eos_and_non_done]
        # self.t_dict['for-eos-create-b'] += default_timer() - t0
        # t0 = default_timer()
        next_scores_selected = next_scores[is_eos_and_non_done]
        # self.t_dict['for-eos-create-c'] += default_timer() - t0
        # t0 = default_timer()
        is_eos_and_non_done_indices = is_eos_and_non_done.nonzero()
        # self.t_dict['for-eos-create-d'] += default_timer() - t0
        # t0 = default_timer()

        next_indices_selected = next_indices_selected.cpu().numpy()
        next_scores_selected = next_scores_selected.cpu().numpy()
        is_eos_and_non_done_indices = is_eos_and_non_done_indices.cpu().numpy()
        # self.t_dict['for-eos-to-cpu'] += default_timer() - t0
        # t0 = default_timer()

        for i, (batch_idx, beam_token_rank) in enumerate(is_eos_and_non_done_indices):
            batch_beam_idx = batch_idx * self.group_size + next_indices_selected[i]
            # if beam_token does not belong to top num_beams tokens, it should not be added
            is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
            if is_beam_token_worse_than_top_num_beams:
                continue
            if beam_indices is not None:
                beam_index = beam_indices[batch_beam_idx]
                beam_index = beam_index + (batch_beam_idx,)
            else:
                beam_index = None

            self._beam_hyps[batch_idx].add(
                input_ids[batch_beam_idx].clone(),
                next_scores_selected[i].item(),
                beam_indices=beam_index,
            )
        # self.t_dict['for-eos-loop'] += default_timer() - t0
        # t0 = default_timer()

        # for non-eos
        first_several_non_eos = first_several_nonzero_indices(
            (next_tokens != eos).int(), batch_enable=~self._done, k=self.num_beams)
        next_beam_scores[:] = next_scores[first_several_non_eos].reshape((-1, self.num_beams))
        next_beam_tokens[:] = next_tokens[first_several_non_eos].reshape((-1, self.num_beams))
        next_beam_indices[:] = batch_beam_indices[first_several_non_eos].reshape((-1, self.num_beams))
        # self.t_dict['for-non-eos'] += default_timer() - t0
        # t0 = default_timer()

        # those who are `done`
        next_beam_scores[self._done, :] = 0
        if pad_token_id is not None:
            next_beam_tokens[self._done, :] = pad_token_id
        next_beam_indices[self._done, :] = 0

        # Check if we are done so that we can save a pad step if all(done)
        next_scores_max = next_scores.max(dim=1)[0].cpu().numpy()
        self._done |= torch.tensor([
            beam_hyp.is_done(next_scores_max[batch_idx], cur_len)
            for batch_idx, beam_hyp in enumerate(self._beam_hyps)
        ], device=device)
        # self.t_dict['done-related'] += default_timer() - t0

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        beam_indices: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.LongTensor]:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_indices = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )


class BeamSearchScorerForComparison:
    """Runs `MyBeamSearchScorer` and the reference `transformers.BeamSearchScorer`
    side by side on identical inputs, asserting after every `process`/`finalize`
    call that their outputs and `is_done` states agree; the reference scorer's
    output is what gets returned to the caller."""

    def __init__(self, **kwargs: Any):
        self.ours = MyBeamSearchScorer(**kwargs)
        self.theirs = BeamSearchScorer(**kwargs)

    @property
    def _beam_hyps(self):
        assert len(self.ours._beam_hyps) == len(self.theirs._beam_hyps)
        return self.ours._beam_hyps

    @property
    def num_beams(self):
        assert self.ours.num_beams == self.theirs.num_beams
        return self.ours.num_beams

    @property
    def is_done(self):
        assert self.ours.is_done == self.theirs.is_done
        return self.ours.is_done

    def process(self, *args: Any, **kwargs: Any):
        ours_output = self.ours.process(*args, **kwargs)
        theirs_output = self.theirs.process(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    def finalize(self, *args: Any, **kwargs: Any):
        ours_output = self.ours.finalize(*args, **kwargs)
        theirs_output = self.theirs.finalize(*args, **kwargs)
        assert isinstance(ours_output, UserDict) and isinstance(theirs_output, UserDict)
        self._check_output_equality(ours_output, theirs_output)
        self._check_state_equality()
        return theirs_output

    @staticmethod
    def _check_output_equality(ours_output: UserDict, theirs_output: UserDict):
        assert set(ours_output.keys()) == set(theirs_output.keys())
        for k in ours_output.keys():
            assert BeamSearchScorerForComparison._tensor_equals(ours_output[k], theirs_output[k]), \
                f'output not equal. key={k} ' \
                f'ours_output={ours_output[k]} theirs_output={theirs_output[k]} '

    def _check_state_equality(self):
        assert self.ours.is_done == self.theirs.is_done

    @staticmethod
    def _tensor_equals(a: Optional[torch.Tensor], b: Optional[torch.Tensor]):
        return (a is None and b is None) or torch.allclose(a, b)
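

# ---------------------------------------------------------------------------
# Usage sketch (not part of the scorers themselves). It shows one way to plug
# `BeamSearchScorerForComparison` into a Hugging Face beam-search call so that
# `MyBeamSearchScorer` is checked against the reference scorer at every step.
# The "t5-small" checkpoint, the prompt, and the direct `model.beam_search(...)`
# call are illustrative assumptions; `beam_search` is only exposed publicly in
# older transformers releases (newer ones route everything through
# `model.generate`). Because of the relative import at the top of this file,
# run it as a module inside its package (`python -m <package>.<this_module>`),
# not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        MaxLengthCriteria,
        StoppingCriteriaList,
    )

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    encoder_input_ids = tokenizer(
        "translate English to German: How old are you?", return_tensors="pt"
    ).input_ids

    num_beams = 4
    # One decoder-start token per beam, as in the standard beam_search recipe.
    input_ids = torch.full(
        (num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long
    )
    model_kwargs = {
        "encoder_outputs": model.get_encoder()(
            encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        )
    }

    # The comparison scorer raises an AssertionError as soon as the two
    # implementations disagree on any output tensor or on `is_done`.
    scorer = BeamSearchScorerForComparison(
        batch_size=1, num_beams=num_beams, device=model.device
    )
    outputs = model.beam_search(
        input_ids,
        scorer,
        stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=32)]),
        **model_kwargs,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))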