diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index 2434fd41d..4c12ffc4d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -78,6 +78,8 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        delay_penalty: float = 0.0,
+        return_sym_delay: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -155,10 +157,31 @@ class Transducer(nn.Module):
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
                 return_grad=True,
             )
 
+        sym_delay = None
+        if return_sym_delay:
+            B, S, T0 = px_grad.shape
+            T = T0 - 1
+            if boundary is None:
+                offset = torch.tensor(
+                    (T - 1) / 2,
+                    dtype=px_grad.dtype,
+                    device=px_grad.device,
+                ).expand(B, 1, 1)
+                total_syms = S * B
+            else:
+                offset = (boundary[:, 3] - 1) / 2
+                total_syms = torch.sum(boundary[:, 2])
+            offset = torch.arange(
+                T0, device=px_grad.device
+            ).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
+            sym_delay = px_grad * offset
+            sym_delay = torch.sum(sym_delay) / total_syms
+
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
@@ -188,7 +211,8 @@ class Transducer(nn.Module):
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
             )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, sym_delay)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 47e2ae1c1..3bd0729b7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -322,6 +322,25 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay. This may be
+        needed when training with time masking, to prevent the time masking
+        from encouraging the network to delay symbols.
+        """,
+    )
+
+    parser.add_argument(
+        "--return-sym-delay",
+        type=str2bool,
+        default=False,
+        help="""Whether to return `sym_delay` during training; this is a
+        statistic measuring symbol emission delay, mainly for time-masking training.
+        """,
+    )
+
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
@@ -625,7 +644,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, sym_delay = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -633,6 +652,8 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            delay_penalty=params.delay_penalty,
+            return_sym_delay=params.return_sym_delay,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -662,6 +683,9 @@ def compute_loss(
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
 
+    if params.return_sym_delay:
+        info["sym_delay"] = sym_delay.detach().cpu().item()
+
     return loss, info
 
 
@@ -905,6 +929,10 @@ def run(rank, world_size, args):
         assert (
             params.causal_convolution
         ), "dynamic_chunk_training requires causal convolution"
+    else:
+        assert (
+            params.delay_penalty == 0.0
+        ), "delay_penalty is intended for dynamic_chunk_training"
 
     logging.info(params)
 
diff --git a/icefall/utils.py b/icefall/utils.py
index 3bfd5e5b1..2e2827ae6 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -546,8 +546,11 @@ class MetricsTracker(collections.defaultdict):
         ans = []
         for k, v in self.items():
             if k != "frames":
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
+                if k != "sym_delay":
+                    norm_value = float(v) / num_frames
+                    ans.append((k, norm_value))
+                else:
+                    ans.append((k, float(v)))
         return ans
 
     def reduce(self, device):
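For reviewers, here is a minimal standalone sketch of the `sym_delay` statistic computed in the `model.py` hunk above. The random `px_grad` is a synthetic stand-in for the `[B, S, T+1]` occupation gradient that `k2.rnnt_loss_smoothed` returns with `return_grad=True`, and the `boundary` rows follow the layout the code above relies on (`boundary[:, 2]` = number of symbols, `boundary[:, 3]` = number of frames). All values are made up for illustration; in a real run padded symbol rows of `px_grad` carry no mass, which this toy version does not model.

```python
# Sketch (assumed shapes, synthetic values) of the sym_delay statistic
# from the model.py hunk above.
import torch

B, S, T = 2, 5, 20  # batch size, max symbols, max frames
T0 = T + 1
# Stand-in for px_grad: each symbol's occupation distribution over frames.
px_grad = torch.softmax(torch.randn(B, S, T0), dim=-1)

# As in the diff: boundary[:, 2] = #symbols, boundary[:, 3] = #frames.
boundary = torch.tensor([[0, 0, 5, 20], [0, 0, 3, 15]])

offset = (boundary[:, 3] - 1) / 2  # center frame of each utterance
total_syms = torch.sum(boundary[:, 2])
# Frame index minus the center, broadcast to [B, 1, T0].
offset = torch.arange(T0).reshape(1, 1, T0) - offset.reshape(B, 1, 1)

# Occupation-weighted average frame offset per symbol: positive means
# symbols are, on average, emitted later than the utterance midpoint.
sym_delay = torch.sum(px_grad * offset) / total_syms
print(f"sym_delay = {sym_delay.item():.3f} frames")
```

Intuitively, each symbol's occupation distribution is dotted with its frame offset from the utterance midpoint, so a growing `sym_delay` during time-masking training signals delayed emissions; the `--delay-penalty` term is intended to counteract that tendency.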