Skip to content

Instantly share code, notes, and snippets.

@tadas-subonis
Created July 31, 2019 10:13
Show Gist options
  • Save tadas-subonis/795f20c5f2b4e549fa2aecc84d474db2 to your computer and use it in GitHub Desktop.
Save tadas-subonis/795f20c5f2b4e549fa2aecc84d474db2 to your computer and use it in GitHub Desktop.

Revisions

  1. tadas-subonis created this gist Jul 31, 2019.
    30 changes: 30 additions & 0 deletions ctc.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,30 @@
    from torch_baidu_ctc import ctc_loss

    class CTCLossT(nn.Module):
    def __init__(self):
    super(CTCLossT, self).__init__()
    self.blank = 0
    self.reduction = 'sum'

    def forward(self, log_probs, targets):
    #log_probs = log_probs.cpu()
    #targets = targets.cpu()
    #print(log_probs.shape, targets.shape)
    #targets = targets.permute(1, 0)
    batch_size = log_probs.size(1)

    #print(log_probs)

    #T = input_image_max_len
    T = log_probs.size(0)
    N = batch_size
    D = targets.size(1)

    input_lengths = torch.full(size=(N,), fill_value=D, dtype=torch.long)
    target_lengths = torch.full(size=(N,), fill_value=D, dtype=torch.long)

    with torch.backends.cudnn.flags(enabled=False):
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, zero_infinity=True)
    #print("loss", loss)

    return loss