import inspect import warnings from collections import defaultdict, UserDict from timeit import default_timer from typing import Optional, Iterable, Union, Callable, List, Tuple import numpy as np import torch import torch.distributed as dist from torch import nn from torch.distributions.constraints import Constraint from transformers import LogitsProcessorList, StoppingCriteriaList, ConstrainedBeamSearchScorer, DisjunctiveConstraint, \ PhrasalConstraint, BeamScorer from transformers.generation import validate_stopping_criteria, BeamHypotheses from transformers.generation.utils import GenerateOutput, GenerationMixin, BeamSearchOutput, \ BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, SampleOutput, SampleEncoderDecoderOutput, \ SampleDecoderOnlyOutput # copied and modified from: huggingface utils.py from transformers.pytorch_utils import torch_int_div from ...utils.huggingface.beam_search import MyBeamSearchScorer class MyBeamSearchScorerViaNumpy(BeamScorer): def __init__( self, batch_size: int, num_beams: int, device: torch.device, length_penalty: Optional[float] = 1.0, do_early_stopping: Optional[bool] = False, num_beam_hyps_to_keep: Optional[int] = 1, num_beam_groups: Optional[int] = 1, **kwargs, ): self.num_beams = num_beams self.device = device self.length_penalty = length_penalty self.do_early_stopping = do_early_stopping self.num_beam_hyps_to_keep = num_beam_hyps_to_keep self.num_beam_groups = num_beam_groups self.group_size = self.num_beams // self.num_beam_groups self._is_init = False self._beam_hyps = [ BeamHypotheses( num_beams=self.num_beams, length_penalty=self.length_penalty, early_stopping=self.do_early_stopping, ) for _ in range(batch_size) ] self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) self.t_dict = defaultdict(lambda: 0.0) if not isinstance(num_beams, int) or num_beams <= 1: raise ValueError( f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1," " one should make use of `greedy_search` instead." ) if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): raise ValueError( "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be" f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." ) if "max_length" in kwargs: warnings.warn( "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. " "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" ", or `group_beam_search(...)`." ) @property def is_done(self) -> bool: return self._done.all() def process( self, input_ids: torch.LongTensor, next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, next_indices: torch.LongTensor, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor]: t0 = default_timer() device = input_ids.device input_ids = input_ids.cpu().numpy() next_scores = next_scores.cpu().numpy() next_tokens = next_tokens.cpu().numpy() next_indices = next_indices.cpu().numpy() self.t_dict['move_to_numpy'] += default_timer() - t0 t0 = default_timer() cur_len = input_ids.shape[-1] batch_size = len(self._beam_hyps) if not (batch_size == (input_ids.shape[0] // self.group_size)): if self.num_beam_groups > 1: raise ValueError( f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam " f"size of {self.group_size} is expected by the beam scorer." ) else: raise ValueError( f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of " f"{self.group_size} is expected by the beam scorer." ) # next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) # next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device) # next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device) next_beam_scores = np.zeros((batch_size, self.group_size), dtype=next_scores.dtype) next_beam_tokens = np.zeros((batch_size, self.group_size), dtype=next_tokens.dtype) next_beam_indices = np.zeros((batch_size, self.group_size), dtype=next_indices.dtype) self.t_dict['create_array'] += default_timer() - t0 t0 = default_timer() # print( # f'beamscorer.process: ' # f'beam_hyps={len(self._beam_hyps)} ' # f'next_scores={next_scores.shape} ' # f'next_tokens={next_tokens.shape} ' # f'next_indices={next_indices.shape} ' # f'pad_token_id={pad_token_id} ' # f'eos_token_id={eos_token_id} ' # ) for batch_idx, beam_hyp in enumerate(self._beam_hyps): # t0 = default_timer() if self._done[batch_idx]: if self.num_beams < len(beam_hyp): raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated") if eos_token_id is None or pad_token_id is None: raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined") # pad the batch next_beam_scores[batch_idx, :] = 0 next_beam_tokens[batch_idx, :] = pad_token_id next_beam_indices[batch_idx, :] = 0 continue # self.t_dict['if_done'] += default_timer() - t0 # t0 = default_timer() # next tokens for this sentence beam_idx = 0 for beam_token_rank, (next_token, next_score, next_index) in enumerate( zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) ): batch_beam_idx = batch_idx * self.group_size + next_index # add to generated hypotheses if end of sentence if (eos_token_id is not None) and (next_token.item() == eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size if is_beam_token_worse_than_top_num_beams: continue if beam_indices is not None: beam_index = beam_indices[batch_beam_idx] beam_index = beam_index + (batch_beam_idx,) else: beam_index = None beam_hyp.add( input_ids[batch_beam_idx].copy(), # input_ids[batch_beam_idx].clone(), next_score.item(), beam_indices=beam_index, ) else: # add next predicted token since it is not eos_token next_beam_scores[batch_idx, beam_idx] = next_score next_beam_tokens[batch_idx, beam_idx] = next_token next_beam_indices[batch_idx, beam_idx] = batch_beam_idx beam_idx += 1 # once the beam for next step is full, don't add more tokens to it. if beam_idx == self.group_size: break # self.t_dict['inner_loop'] += default_timer() - t0 # t0 = default_timer() if beam_idx < self.group_size: raise ValueError( f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:" f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." ) # Check if we are done so that we can save a pad step if all(done) self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( next_scores[batch_idx].max().item(), cur_len ) # self.t_dict['postpare'] += default_timer() - t0 self.t_dict['body'] += default_timer() - t0 t0 = default_timer() ans = UserDict( { # "next_beam_scores": next_beam_scores.view(-1), # "next_beam_tokens": next_beam_tokens.view(-1), # "next_beam_indices": next_beam_indices.view(-1), "next_beam_scores": torch.tensor(next_beam_scores, device=device).view(-1), "next_beam_tokens": torch.tensor(next_beam_tokens, device=device).view(-1), "next_beam_indices": torch.tensor(next_beam_indices, device=device).view(-1), } ) self.t_dict['ans'] += default_timer() - t0 t0 = default_timer() return ans def finalize( self, input_ids: torch.LongTensor, final_beam_scores: torch.FloatTensor, final_beam_tokens: torch.LongTensor, final_beam_indices: torch.LongTensor, max_length: int, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, beam_indices: Optional[torch.LongTensor] = None, ) -> Tuple[torch.LongTensor]: batch_size = len(self._beam_hyps) device = input_ids.device input_ids = input_ids.cpu().numpy() final_beam_scores = final_beam_scores.cpu().numpy() # final_beam_tokens = final_beam_tokens.cpu().numpy() # final_beam_indices = final_beam_indices.cpu().numpy() # finalize all open beam hypotheses and add to generated hypotheses for batch_idx, beam_hyp in enumerate(self._beam_hyps): if self._done[batch_idx]: continue # all open beam hypotheses are added to the beam hypothesis # beam hypothesis class automatically keeps the best beams for beam_id in range(self.num_beams): batch_beam_idx = batch_idx * self.num_beams + beam_id final_score = final_beam_scores[batch_beam_idx].item() final_tokens = input_ids[batch_beam_idx] beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None beam_hyp.add(final_tokens, final_score, beam_indices=beam_index) # select the best hypotheses # sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) sent_lengths = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=int) best = [] best_indices = [] # best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32) best_scores = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=np.float32) # retrieve best hypotheses for i, beam_hyp in enumerate(self._beam_hyps): sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) for j in range(self.num_beam_hyps_to_keep): best_hyp_tuple = sorted_hyps.pop() best_score = best_hyp_tuple[0] best_hyp = best_hyp_tuple[1] best_index = best_hyp_tuple[2] sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) # append hyp to lists best.append(best_hyp) # append indices to list best_indices.append(best_index) best_scores[i * self.num_beam_hyps_to_keep + j] = best_score # prepare for adding eos sent_lengths_max = sent_lengths.max().item() + 1 sent_max_len = int(min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max) # decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) decoded = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=int) if len(best_indices) > 0 and best_indices[0] is not None: # indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) indices = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=int) else: indices = None # shorter batches are padded if needed if sent_lengths.min().item() != sent_lengths.max().item(): assert pad_token_id is not None, "`pad_token_id` has to be defined" decoded.fill(pad_token_id) if indices is not None: indices.fill(-1) # fill with hypotheses and eos_token_id if the latter fits in for i, (hypo, best_idx) in enumerate(zip(best, best_indices)): decoded[i, : sent_lengths[i]] = hypo if indices is not None: indices[i, : len(best_idx)] = torch.tensor(best_idx) if sent_lengths[i] < sent_max_len: decoded[i, sent_lengths[i]] = eos_token_id return UserDict( { # "sequences": decoded, # "sequence_scores": best_scores, # "beam_indices": indices, "sequences": torch.tensor(decoded, device=device), "sequence_scores": torch.tensor(best_scores, device=device), "beam_indices": torch.tensor(indices, device=device) if indices is not None else None, } ) # ChosenBeamSearchScorer = BeamSearchScorer # ChosenBeamSearchScorer = BeamSearchScorerForComparison ChosenBeamSearchScorer = MyBeamSearchScorer # ChosenBeamSearchScorer = MyBeamSearchScorerViaNumpy @torch.no_grad() def my_generate( self: GenerationMixin, inputs: Optional[torch.Tensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, penalty_alpha: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, max_time: Optional[float] = None, max_new_tokens: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, num_beam_groups: Optional[int] = None, diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = None, renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[int, float]] = None, suppress_tokens: Optional[List[int]] = None, begin_suppress_tokens: Optional[List[int]] = None, forced_decoder_ids: Optional[List[List[int]]] = None, **model_kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: print('my_generate called') # 0. Validate the `.generate()` call self._validate_model_class() self._validate_model_kwargs(model_kwargs.copy()) # 1. Set generation parameters if not already defined bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id num_beams = num_beams if num_beams is not None else self.config.num_beams length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id if eos_token_id is None and hasattr(self.config, "decoder"): eos_token_id = self.config.decoder.eos_token_id if pad_token_id is None and eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # 2. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) batch_size = inputs_tensor.shape[0] # 3. Define other model kwargs model_kwargs["output_attentions"] = output_attentions model_kwargs["output_hidden_states"] = output_hidden_states model_kwargs["use_cache"] = use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, pad_token_id, eos_token_id ) # decoder-only models should use left-padding for generation if not self.config.is_encoder_decoder: if pad_token_id is not None and torch.sum(inputs_tensor[:, -1] == pad_token_id) > 0: logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer." ) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created # and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) # 4. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, model_kwargs=model_kwargs, device=inputs_tensor.device, ) else: # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor # 5. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] if max_length is None and max_new_tokens is None: warnings.warn( "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to " f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is " "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend " "using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) elif max_length is None and max_new_tokens is not None: max_length = max_new_tokens + input_ids_seq_length elif max_length is not None and max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" " documentation for more information. " "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) # default to config if still None max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length if min_length is not None and min_length > max_length: raise ValueError( f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum " f"length ({max_length})" ) if input_ids_seq_length >= max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {max_length}. This can lead to unexpected behavior. You should consider increasing " "`max_new_tokens`." ) # 6. determine generation mode is_constraint_gen_mode = constraints is not None or force_words_ids is not None is_contrastive_search_gen_mode = ( top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 ) is_greedy_gen_mode = ( (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_sample_gen_mode = ( (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_gen_mode = ( (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_beam_sample_gen_mode = ( (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) is_group_beam_gen_mode = ( (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode and not is_contrastive_search_gen_mode ) if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) if self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." " Please make sure that you have put `input_ids` to the" f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" " running `.generate()`.", UserWarning, ) # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, bad_words_ids=bad_words_ids, min_length=min_length, max_length=max_length, eos_token_id=eos_token_id, forced_bos_token_id=forced_bos_token_id, forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, remove_invalid_values=remove_invalid_values, exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, renormalize_logits=renormalize_logits, suppress_tokens=suppress_tokens, begin_suppress_tokens=begin_suppress_tokens, forced_decoder_ids=forced_decoder_ids, ) # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria ) # 9. go into different generation modes if is_greedy_gen_mode: if num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) # 10. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_contrastive_search_gen_mode: if num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." ) return self.contrastive_search( input_ids, top_k=top_k, penalty_alpha=penalty_alpha, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams, renormalize_logits=renormalize_logits, ) # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run sample return my_sample( self, input_ids, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 10. prepare beam search scorer beam_scorer = ChosenBeamSearchScorer( batch_size=batch_size, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run beam search return my_beam_search( self, input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams, renormalize_logits=renormalize_logits, ) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. prepare beam search scorer beam_scorer = ChosenBeamSearchScorer( batch_size=batch_size * num_return_sequences, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=num_beams * num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. run beam sample return self.beam_sample( input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_group_beam_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if num_beams % num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") if typical_p is not None: raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") # 10. prepare beam search scorer beam_scorer = MyBeamSearchScorerViaNumpy( batch_size=batch_size, num_beams=num_beams, max_length=stopping_criteria.max_length, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, num_beam_groups=num_beam_groups, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run beam search return self.group_beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_constraint_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") if num_beams <= 1: raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") if do_sample: raise ValueError("`do_sample` needs to be false for constrained generation.") if num_beam_groups is not None and num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] if constraints is not None: final_constraints = constraints if force_words_ids is not None: def typeerror(): raise ValueError( "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" f"of positive integers, but is {force_words_ids}." ) if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: typeerror() for word_ids in force_words_ids: if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any(not isinstance(token_ids, list) for token_ids in word_ids): typeerror() if any( any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids ): typeerror() constraint = DisjunctiveConstraint(word_ids) else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): typeerror() constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) # 10. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) def my_sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) t_start = default_timer() t_dict = defaultdict(lambda: 0.0) this_peer_finished = False # used by synced_gpus only # auto-regressive generation while True: t0 = default_timer() if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) t_dict['prepare_inputs'] += default_timer() - t0 t0 = default_timer() # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) t_dict['model'] += default_timer() - t0 t0 = default_timer() if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) t_dict['sample'] += default_timer() - t0 t0 = default_timer() # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True t_dict['misc'] += default_timer() - t0 print( f'sample ' f't_total={default_timer() - t_start:.3f}s ' 't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()}) ) if return_dict_in_generate: if self.config.is_encoder_decoder: return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def my_beam_search( self: GenerationMixin, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens # of the first beam are considered to avoid sampling the exact same tokens across all beams. beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) t_start = default_timer() t_dict = defaultdict(lambda: 0.0) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break t0 = default_timer() model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # t_dict['prepare_inputs'] += default_timer() - t0 # t0 = default_timer() outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) # t_dict['model'] += default_timer() - t0 # t0 = default_timer() if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] t_dict['_'] += next_token_logits.sum().cpu().numpy() t_dict['model_2'] += default_timer() - t0 t0 = default_timer() # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) # t_dict['_'] += next_token_scores.sum().cpu().numpy() # t_dict['calc_next_token'] += default_timer() - t0 # t0 = default_timer() next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # t_dict['_'] += next_token_scores.sum().cpu().numpy() # t_dict['calc_scores'] += default_timer() - t0 # t0 = default_timer() # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # t_dict['related_return_dict_in_generate'] += default_timer() - t0 # t0 = default_timer() # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) # t_dict['_'] += next_token_scores.sum().cpu().numpy() # t_dict['reshape'] += default_timer() - t0 # t0 = default_timer() # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) # t_dict['_'] += next_token_scores.sum().cpu().numpy() # t_dict['topk'] += default_timer() - t0 # t0 = default_timer() next_indices = torch_int_div(next_tokens, vocab_size) next_tokens = next_tokens % vocab_size # t_dict['sample'] += default_timer() - t0 # t0 = default_timer() # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=beam_indices, ) t_dict['scorer'] += default_timer() - t0 t0 = default_timer() beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 t_dict['updates'] += default_timer() - t0 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=beam_indices, ) print( f'beam_search ' f't_total={default_timer() - t_start:.3f}s ' 't_dict=' + str({k: f'{v:.3}s' for k, v in t_dict.items()}) # ' beam_scorer=' + str({k: f'{v:.3}s' for k, v in beam_scorer.t_dict.items()}) ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"]