Gitingest - vadimkantorov/ctc
Summary
Repository: vadimkantorov/ctc
Files analyzed: 3
Estimated tokens: 5.5k
Directory Structure
Directory structure:
└── vadimkantorov-ctc/
├── README.md
├── ctc.py
└── example.py
Files Content
================================================
FILE: README.md
================================================
A primer on CTC implementation in pure Python PyTorch code. This impl is not suitable for real-world usage, only for experimentation and research on CTC modifications. Features (a short usage sketch follows this README):
- CTC impl is in Python and its only loop is over time steps (parallelizes over batch and symbol dimensions)
- Gradients are computed via PyTorch autograd instead of a separate beta computation
- Viterbi path useful for forced alignment
- Get alignment targets out of any CTC impl, so that label smoothing or reweighting can be applied [1, 2]
- It might support double-backwards (not checked)

### Very rough time measurements
```
Device: cuda
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.002052783966064453 bwd 0.0167086124420166
Custom CTC loss fwd 0.09685754776000977 bwd 0.14192843437194824
Custom loss matches: True
Grad matches: True
CE grad matches: True

Device: cpu
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.017746925354003906 bwd 0.21297860145568848
Custom CTC loss fwd 0.38710451126098633 bwd 5.190514087677002
Custom loss matches: True
Grad matches: True
CE grad matches: True
```

### Very rough time measurements if custom logsumexp is used
```
Device: cuda
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.009581804275512695 bwd 0.012355327606201172
Custom CTC loss fwd 0.09775996208190918 bwd 0.1494584083557129
Custom loss matches: True
Grad matches: True
CE grad matches: True

Device: cpu
Log-probs shape (time X batch X channels): 128x256x32
Built-in CTC loss fwd 0.017041444778442383 bwd 0.23205327987670898
Custom CTC loss fwd 0.3748452663421631 bwd 4.206061363220215
Custom loss matches: True
Grad matches: True
CE grad matches: True
```

### Alignment image example
(image omitted in this text ingest; example.py saves such a figure as alignment.png)

### References (CTC)
1. A Novel Re-weighting Method for Connectionist Temporal Classification; Li et al.; https://arxiv.org/abs/1904.10619
2. Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets; Feng et al.; https://www.hindawi.com/journals/complexity/2019/9345861/
3. Improved training for online end-to-end speech recognition systems; Kim et al.; https://arxiv.org/abs/1711.02212
4. Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks; Graves et al.; https://www.cs.toronto.edu/~graves/icml_2006.pdf
5. Sequence Modeling With CTC; Hannun et al.; https://distill.pub/2017/ctc/
6. My two related gists:
   - Loop-based CTC forward: https://gist.github.com/vadimkantorov/c1aa417cffa1450b03716c740795f107
   - CTC targets: https://gist.github.com/vadimkantorov/73e1915178f444b64f9ef01a1e96c1e4
7. Other CTC implementations:
   - https://github.com/rakeshvar/rnn_ctc/blob/master/nnet/ctc.py#L96
   - https://github.com/artbataev/end2end/blob/master/pytorch_end2end/src/losses/forward_backward.cpp
   - https://github.com/jamesdanged/LatticeCtc
   - https://github.com/zh217/torch-asg/blob/master/torch_asg/native/force_aligned_lattice.cpp
   - https://github.com/amaas/stanford-ctc/blob/master/ctc/ctc.py
   - https://github.com/skaae/Lasagne-CTC/blob/master/ctc_cost.py
   - https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LossCTC.cpp#L37
   - https://github.com/musyoku/chainer-gram-ctc
   - https://github.com/musyoku/chainer-cuda-ctc
   - https://github.com/1ytic/warp-rnnt

### References (beam search)
- https://towardsdatascience.com/beam-search-decoding-in-ctc-trained-neural-networks-5a889a3d85a7
- https://medium.com/corti-ai/ctc-networks-and-language-models-prefix-beam-search-explained-c11d1ee23306
- https://github.com/githubharald/CTCDecoder
- https://github.com/githubharald/CTCWordBeamSearch
- https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
- https://github.com/wouterkool/stochastic-beam-search
- https://github.com/mjansche/ctc_sampling
- https://www.aclweb.org/anthology/D19-1331/
- https://arxiv.org/abs/1905.08760
- https://arxiv.org/abs/1804.07915
- http://proceedings.mlr.press/v97/cohen19a/cohen19a.pdf
- https://github.com/corticph/prefix-beam-search/
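Before the sources themselves, here is a minimal usage sketch. It is an editorial illustration, not a file from this repository: it assumes ctc.py is importable as `ctc`, the toy sizes are made up, and the (time x batch x channels) layout and call signatures follow example.py further below.

```python
# Minimal sketch, assuming ctc.py from this repository is on the import path; toy sizes are made up
import torch
import ctc

T, B, C, L = 16, 4, 8, 5  # time, batch, classes, target length (illustrative only)
logits = torch.randn(T, B, C, requires_grad = True)
log_probs = logits.log_softmax(dim = -1)
targets = torch.randint(1, C, (B, L), dtype = torch.long)   # class 0 is the blank
input_lengths = torch.full((B,), T, dtype = torch.long)
target_lengths = torch.full((B,), L, dtype = torch.long)

# Custom CTC loss; gradients flow through autograd, there is no separate beta pass
loss = ctc.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
loss.sum().backward()

# Viterbi forced alignment: one input frame index per output symbol, shape (B, L)
alignment = ctc.ctc_alignment(log_probs.detach(), targets, input_lengths, target_lengths, blank = 0)
print(loss.shape, alignment.shape)  # torch.Size([4]) torch.Size([4, 5])
```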
================================================
FILE: ctc.py
================================================
# TODO: try to replace fancy tensor indexing by gather / scatter

import math
import torch

#@torch.jit.script
def ctc_loss(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank : int = 0, reduction : str = 'none', finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min, alignment : bool = False) -> torch.Tensor:
    input_time_size, batch_size = log_probs.shape[:2]
    B = torch.arange(batch_size, device = input_lengths.device)

    _t_a_r_g_e_t_s_ = torch.cat([targets, targets[:, :1]], dim = -1)
    _t_a_r_g_e_t_s_ = torch.stack([torch.full_like(_t_a_r_g_e_t_s_, blank), _t_a_r_g_e_t_s_], dim = -1).flatten(start_dim = -2)
    diff_labels = torch.cat([torch.tensor([[False, False]], device = targets.device).expand(batch_size, -1), _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]], dim = 1)

    # if zero = float('-inf') is used as neutral element, custom logsumexp must be used to avoid nan grad in torch.logsumexp
    zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
    log_probs_ = log_probs.gather(-1, _t_a_r_g_e_t_s_.expand(input_time_size, -1, -1))
    log_alpha = torch.full((input_time_size, batch_size, zero_padding + _t_a_r_g_e_t_s_.shape[-1]), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[0, :, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[0, :, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]
    # log_alpha[1:, :, zero_padding:] = log_probs.gather(-1, _t_a_r_g_e_t_s_.expand(len(log_probs), -1, -1))[1:]
    for t in range(1, input_time_size):
        log_alpha[t, :, 2:] = log_probs_[t] + logadd(log_alpha[t - 1, :, 2:], log_alpha[t - 1, :, 1:-1], torch.where(diff_labels, log_alpha[t - 1, :, :-2], zero))

    l1l2 = log_alpha[input_lengths - 1, B].gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))
    loss = -torch.logsumexp(l1l2, dim = -1)

    if not alignment:
        return loss

    # below is for debugging only; for real alignment use the more efficient, distinct ctc_alignment(...) method
    path = torch.zeros(len(log_alpha), len(B), device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + 2 * target_lengths - 1 + l1l2.argmax(dim = -1)
    for t, indices in reversed(list(enumerate(path))[1:]):
        indices_ = torch.stack([(indices - 2) * diff_labels[B, (indices - zero_padding).clamp(min = 0)], (indices - 1).clamp(min = 0), indices], dim = -1)
        path[t - 1] += (indices - 2 + log_alpha[t - 1, B].gather(-1, indices_).argmax(dim = -1)).clamp(min = 0)
    return torch.zeros_like(log_alpha).scatter_(-1, path.unsqueeze(-1), 1.0)[..., (zero_padding + 1)::2]


#@torch.jit.script
def ctc_alignment(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank: int = 0, finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min) -> torch.Tensor:
    input_time_size, batch_size = log_probs.shape[:2]
    B = torch.arange(batch_size, device = input_lengths.device)

    _t_a_r_g_e_t_s_ = torch.cat([torch.stack([torch.full_like(targets, blank), targets], dim = -1).flatten(start_dim = -2), torch.full_like(targets[:, :1], blank)], dim = -1)
    diff_labels = torch.cat([torch.tensor([[False, False]], device = targets.device).expand(batch_size, -1), _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]], dim = 1)

    zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
    padded_t = zero_padding + _t_a_r_g_e_t_s_.shape[-1]
    log_alpha = torch.full((batch_size, padded_t), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[:, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[:, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]

    packmask = 0b11
    packnibbles = 4
    # packnibbles = 1
    backpointers_shape = [len(log_probs), batch_size, int(math.ceil(padded_t / packnibbles))]
    backpointers = torch.zeros(backpointers_shape, device = log_probs.device, dtype = torch.uint8)
    backpointer = torch.zeros((backpointers_shape[-2], backpointers_shape[-1] * packnibbles), device = log_probs.device, dtype = torch.uint8)
    packshift = torch.tensor([[[6, 4, 2, 0]]], device = log_probs.device, dtype = torch.uint8)

    for t in range(1, input_time_size):
        prev = torch.stack([log_alpha[:, 2:], log_alpha[:, 1:-1], torch.where(diff_labels, log_alpha[:, :-2], zero)])
        log_alpha[:, zero_padding:] = log_probs[t].gather(-1, _t_a_r_g_e_t_s_) + prev.logsumexp(dim = 0)
        backpointer[:, zero_padding:(zero_padding + prev.shape[-1])] = prev.argmax(dim = 0)
        torch.sum(backpointer.unflatten(-1, (-1, packnibbles)) << packshift, dim = -1, out = backpointers[t])
        # backpointers[t] = backpointer

    l1l2 = log_alpha.gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))

    path = torch.zeros(input_time_size, batch_size, device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + target_lengths * 2 - 1 + l1l2.argmax(dim = -1)
    for t in range(input_time_size - 1, 0, -1):
        indices = path[t]
        backpointer = (backpointers[t].unsqueeze(-1) >> packshift).view_as(backpointer)
        #backpointer = backpointers[t]
        path[t - 1] += indices - backpointer.gather(-1, indices.unsqueeze(-1)).squeeze(-1).bitwise_and_(packmask)

    return torch.zeros_like(_t_a_r_g_e_t_s_, dtype = torch.int64).scatter_(-1, (path.t() - zero_padding).clamp(min = 0), torch.arange(input_time_size, device = log_alpha.device).expand(batch_size, -1))[:, 1::2]


def ctc_alignment_uncompressed(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank: int = 0, pack_backpointers: bool = False, finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min) -> torch.Tensor:
    B = torch.arange(len(targets), device = input_lengths.device)
    _t_a_r_g_e_t_s_ = torch.cat([torch.stack([torch.full_like(targets, blank), targets], dim = -1).flatten(start_dim = -2), torch.full_like(targets[:, :1], blank)], dim = -1)
    diff_labels = torch.cat([torch.as_tensor([[False, False]], device = targets.device).expand(len(B), -1), _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2]], dim = 1)

    zero, zero_padding = torch.tensor(finfo_min_fp16 if log_probs.dtype is torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype), 2
    padded_t = zero_padding + _t_a_r_g_e_t_s_.shape[-1]
    log_alpha = torch.full((len(B), padded_t), zero, device = log_probs.device, dtype = log_probs.dtype)
    log_alpha[:, zero_padding + 0] = log_probs[0, :, blank]
    log_alpha[:, zero_padding + 1] = log_probs[0, B, _t_a_r_g_e_t_s_[:, 1]]

    packmask = 0b11
    packnibbles = 4
    padded_t = int(math.ceil(padded_t / packnibbles)) * packnibbles
    backpointers_shape = [len(log_probs), len(B), padded_t]
    backpointers = torch.zeros(backpointers_shape if not pack_backpointers else (backpointers_shape[:-1] + [padded_t // packnibbles]), device = log_probs.device, dtype = torch.uint8)
    backpointer = torch.zeros(backpointers_shape[1:], device = log_probs.device, dtype = torch.uint8)
    packshift = torch.tensor([[[6, 4, 2, 0]]], device = log_probs.device, dtype = torch.uint8)

    for t in range(1, len(log_probs)):
        prev = torch.stack([log_alpha[:, 2:], log_alpha[:, 1:-1], torch.where(diff_labels, log_alpha[:, :-2], zero)])
        log_alpha[:, 2:] = log_probs[t].gather(-1, _t_a_r_g_e_t_s_) + prev.logsumexp(dim = 0)
        backpointer[:, 2:(2 + prev.shape[-1])] = prev.argmax(dim = 0)
        if pack_backpointers:
            torch.sum(backpointer.view(len(backpointer), -1, 4) << packshift, dim = -1, out = backpointers[t])
        else:
            backpointers[t] = backpointer

    l1l2 = log_alpha.gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))

    path = torch.zeros(len(log_probs), len(B), device = log_alpha.device, dtype = torch.int64)
    path[input_lengths - 1, B] = zero_padding + target_lengths * 2 - 1 + l1l2.argmax(dim = -1)
    for t in range(len(path) - 1, 0, -1):
        indices = path[t]
        if pack_backpointers:
            backpointer = (backpointers[t].unsqueeze(-1) >> packshift).view_as(backpointer)
        else:
            backpointer = backpointers[t]
        path[t - 1] += indices - backpointer.gather(-1, indices.unsqueeze(-1)).squeeze(-1).bitwise_and_(packmask)

    return torch.zeros_like(_t_a_r_g_e_t_s_, dtype = torch.int64).scatter_(-1, (path.t() - zero_padding).clamp(min = 0), torch.arange(len(path), device = log_alpha.device).expand(len(B), -1))[:, 1::2]
def ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0, ctc_loss = torch.nn.functional.ctc_loss, retain_graph = True):
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = blank, reduction = 'sum')
    probs = log_probs.exp()

    # to simplify API we inline log_softmax gradient, i.e. next two lines are equivalent to: grad_logits, = torch.autograd.grad(loss, logits, retain_graph = True)
    # gradient formula explained at https://stackoverflow.com/questions/35304393/trying-to-understand-code-that-computes-the-gradient-wrt-to-the-input-for-logsof
    grad_log_probs, = torch.autograd.grad(loss, log_probs, retain_graph = retain_graph)
    grad_logits = grad_log_probs - probs * grad_log_probs.sum(dim = -1, keepdim = True)

    temporal_mask = (torch.arange(len(log_probs), device = input_lengths.device, dtype = input_lengths.dtype).unsqueeze(1) < input_lengths.unsqueeze(0)).unsqueeze(-1)
    return (probs * temporal_mask - grad_logits).detach()


def logadd(x0, x1, x2):
    # produces nan gradients in backward if -inf log-space zero element is used https://github.com/pytorch/pytorch/issues/31829
    return torch.logsumexp(torch.stack([x0, x1, x2]), dim = 0)

    # use if -inf log-space zero element is used
    #return LogsumexpFunction.apply(x0, x1, x2)

    # produces inplace modification error https://github.com/pytorch/pytorch/issues/31819
    #m = torch.max(torch.max(x0, x1), x2)
    #m = m.masked_fill(torch.isinf(m), 0)
    #res = (x0 - m).exp() + (x1 - m).exp() + (x2 - m).exp()
    #return res.log().add(m)


class LogsumexpFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(self, x0, x1, x2):
        m = torch.max(torch.max(x0, x1), x2)
        m = m.masked_fill_(torch.isinf(m), 0)
        e0 = (x0 - m).exp_()
        e1 = (x1 - m).exp_()
        e2 = (x2 - m).exp_()
        e = (e0 + e1).add_(e2).clamp_(min = 1e-16)
        self.save_for_backward(e0, e1, e2, e)
        return e.log_().add_(m)

    @staticmethod
    def backward(self, grad_output):
        e0, e1, e2, e = self.saved_tensors
        g = grad_output / e
        return (g * e0, g * e1, g * e2)
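The point of ctc_alignment_targets above is that CTC can be rewritten as frame-wise cross-entropy against soft per-frame targets, which can then be modified before taking the loss (the label smoothing / reweighting use case from the README). Below is a minimal sketch of that idea; the function name smoothed_ctc_loss, the smoothing value 0.1, and the uniform mixing scheme are illustrative choices, not something defined in this repository, and log_probs is assumed to require grad (e.g. logits.log_softmax(dim = -1)).

```python
# Illustrative sketch only: per-frame label smoothing on top of CTC alignment targets
import torch
import ctc  # assumes ctc.py from this repository is importable

def smoothed_ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, smoothing = 0.1):
    # soft per-frame targets whose cross-entropy gradient matches the CTC gradient (see ctc_alignment_targets)
    ce_targets = ctc.ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = blank)
    # mix in a uniform distribution over classes on valid frames only (plain label smoothing, chosen for illustration)
    num_classes = log_probs.shape[-1]
    valid = (torch.arange(len(log_probs), device = log_probs.device).unsqueeze(-1) < input_lengths).unsqueeze(-1)
    ce_targets = torch.where(valid, (1 - smoothing) * ce_targets + smoothing / num_classes, ce_targets)
    # frame-wise cross-entropy against the smoothed soft targets, shape (T, B); sum or average as needed
    return -(ce_targets * log_probs).sum(dim = -1)
```

With smoothing = 0 this reduces to the ce_ctc construction checked in example.py below, whose gradient matches the built-in CTC loss.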
================================================
FILE: example.py
================================================
import time
import matplotlib.pyplot as plt
import torch
import ctc

T, B, C = 128, 256, 32
t = T // 2 - 4
blank = 0
device = 'cpu'#'cuda'
seed = 1
atol = 1e-3

for set_seed in [torch.manual_seed] + ([torch.cuda.manual_seed_all] if device == 'cuda' else []):
    set_seed(seed)

tictoc = lambda: (device == 'cuda' and torch.cuda.synchronize()) or time.time()

logits = torch.randn(T, B, C, device = device).requires_grad_()
targets = torch.randint(blank + 1, C, (B, t), dtype = torch.long, device = device)
input_lengths = torch.full((B,), T, dtype = torch.long, device = device)
target_lengths = torch.full((B,), t, dtype = torch.long, device = device)

log_probs = logits.log_softmax(dim = -1)

print('Device:', device)
print('Log-probs shape (time X batch X channels):', 'x'.join(map(str, log_probs.shape)))

tic = tictoc()
builtin_ctc = torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
toc = tictoc()
builtin_ctc_grad, = torch.autograd.grad(builtin_ctc.sum(), logits, retain_graph = True)
print('Built-in CTC loss', 'fwd', toc - tic, 'bwd', tictoc() - toc)

tic = tictoc()
custom_ctc = ctc.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'none')
toc = tictoc()
custom_ctc_grad, = torch.autograd.grad(custom_ctc.sum(), logits, retain_graph = True)
print('Custom CTC loss', 'fwd', toc - tic, 'bwd', tictoc() - toc)

ce_alignment_targets = ctc.ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0)
ce_ctc = -ce_alignment_targets * log_probs
ce_ctc_grad, = torch.autograd.grad(ce_ctc.sum(), logits, retain_graph = True)

print('Custom loss matches:', torch.allclose(builtin_ctc, custom_ctc, atol = atol))
print('Grad matches:', torch.allclose(builtin_ctc_grad, custom_ctc_grad, atol = atol))
print('CE grad matches:', torch.allclose(builtin_ctc_grad, ce_ctc_grad, atol = atol))

alignment = ctc.ctc_alignment(log_probs, targets, input_lengths, target_lengths, blank = 0)

a = torch.zeros(T, t); a[alignment[0, :target_lengths[0]], torch.arange(t)] = 1.0
plt.subplot(211)
plt.title('Input-Output Viterbi alignment')
plt.imshow(a.t().cpu(), origin = 'lower', aspect = 'auto')
plt.xlabel('Input steps')
plt.ylabel('Output steps')

plt.subplot(212)
plt.title('CTC alignment targets')
a = ce_alignment_targets[:, 0, :]
plt.imshow(a.t().cpu(), origin = 'lower', aspect = 'auto')
plt.xlabel('Input steps')
plt.ylabel(f'Output symbols, blank {blank}')

plt.subplots_adjust(hspace = 0.5)
plt.savefig('alignment.png')
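Finally, since ctc_alignment returns one input frame index per output symbol, turning the Viterbi alignment into timestamps only needs the frame hop of the upstream feature extraction. The helper below is an illustrative sketch; frame_hop_sec is a made-up parameter, not something defined in this repository.

```python
# Illustrative helper; frame_hop_sec is hypothetical and depends on the upstream feature extraction
def alignment_to_timestamps(alignment_row, target_row, frame_hop_sec = 0.02):
    # alignment_row: (L,) frame index assigned to each output symbol by the Viterbi path
    # target_row:    (L,) corresponding label ids
    # returns a list of (label_id, time_in_seconds) pairs
    return [(int(label), float(frame) * frame_hop_sec) for label, frame in zip(target_row.tolist(), alignment_row.tolist())]

# e.g., continuing from example.py:
# print(alignment_to_timestamps(alignment[0, :target_lengths[0]], targets[0, :target_lengths[0]]))
```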