from torch_baidu_ctc import ctc_loss


class CTCLossT(nn.Module):
    """CTC loss wrapper that assumes full-length inputs and targets.

    Every sequence in the batch is treated as using all time steps of
    ``log_probs`` and all label positions of ``targets``; no per-sample
    length information is accepted.
    """

    def __init__(self):
        super().__init__()
        # Label index reserved for the CTC blank symbol.
        self.blank = 0
        # Per-sample losses are summed over the batch.
        self.reduction = 'sum'

    def forward(self, log_probs, targets):
        """Return the summed CTC loss for a batch.

        Args:
            log_probs: Tensor indexed as (time, batch, ...); dim 0 is used
                as the input sequence length and dim 1 as the batch size.
                Expected to hold log-probabilities, as required by
                ``F.ctc_loss``.
            targets: Tensor indexed as (batch, target_len); dim 1 is used
                as the target length of every sample.

        Returns:
            The CTC loss reduced according to ``self.reduction``. Infinite
            per-sample losses are zeroed (``zero_infinity=True``).
        """
        time_steps = log_probs.size(0)
        batch_size = log_probs.size(1)
        target_len = targets.size(1)

        # Input lengths must describe the time dimension of ``log_probs``,
        # not the target length: CTC needs at least as many frames as
        # labels (more when labels repeat), so using the target length here
        # would truncate the input and make many alignments infeasible.
        input_lengths = torch.full(
            size=(batch_size,), fill_value=time_steps, dtype=torch.long
        )
        target_lengths = torch.full(
            size=(batch_size,), fill_value=target_len, dtype=torch.long
        )

        # cuDNN is disabled for this call so that it is not used to
        # compute the CTC loss.
        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                log_probs,
                targets,
                input_lengths,
                target_lengths,
                blank=self.blank,
                reduction=self.reduction,
                zero_infinity=True,
            )
        return loss