From 752e16be1038211bada5d4f15eb4b59d3f6ae9f6 Mon Sep 17 00:00:00 2001
From: Erwan
Date: Tue, 15 Aug 2023 10:26:38 +0200
Subject: [PATCH] Add BS to zipformer2

---
 egs/librispeech/ASR/zipformer/frame_reducer.py | 6 +++++-
 egs/librispeech/ASR/zipformer/model.py         | 3 +--
 egs/librispeech/ASR/zipformer/train.py         | 7 +++++++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/frame_reducer.py b/egs/librispeech/ASR/zipformer/frame_reducer.py
index 60c4f37ba..099e76e22 100644
--- a/egs/librispeech/ASR/zipformer/frame_reducer.py
+++ b/egs/librispeech/ASR/zipformer/frame_reducer.py
@@ -27,6 +27,8 @@ import torch.nn.functional as F
 
 from icefall.utils import make_pad_mask
 
+NON_BLANK_THRES = 0.9
+
 
 class FrameReducer(nn.Module):
     """The encoder output is first used to calculate
@@ -72,7 +74,9 @@ class FrameReducer(nn.Module):
         N, T, C = x.size()
 
         padding_mask = make_pad_mask(x_lens)
-        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
+            ~padding_mask
+        )
 
         if y_lens is not None or self.training is False:
             # Limit the maximum number of reduced frames
diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index 849ea798a..c1ab7b47f 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -181,7 +181,6 @@ class AsrModel(nn.Module):
           reduction:
             Specifies the reduction to apply to the output
         """
-        # TODO: Add delay penalty to CTC Loss
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
        encoder_out_fr = encoder_out
@@ -211,7 +210,7 @@ class AsrModel(nn.Module):
                token_ids=targets,
            )
 
-            # TODO: Find out why we need to do that but not in icefall
+            # TODO: Crash without this line
            supervision_segments = supervision_segments.to("cpu")
            decoding_graph = k2.ctc_graph(
                token_ids, modified=False, device=encoder_out.device
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index cf8ac1d30..ba10c18e4 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -538,6 +538,7 @@ def get_params() -> AttributeDict:
            "reset_interval": 200,
            "valid_interval": 3000,  # For the 100h subset, use 800
            # parameters for zipformer
+            "ctc_beam_size": 10,
            "feature_dim": 80,
            "subsampling_factor": 4,  # not passed in, this is fixed.
            "warm_step": 2000,
@@ -783,6 +784,7 @@ def compute_loss(
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
+    warmup: float,
 ) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.
@@ -823,9 +825,12 @@ def compute_loss(
            x=feature,
            x_lens=feature_lens,
            y=y,
+            supervisions=supervisions,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
+            ctc_beam_size=params.ctc_beam_size,
+            warmup=warmup,
        )
 
        loss = 0.0
@@ -886,6 +891,7 @@ def compute_validation_loss(
            sp=sp,
            batch=batch,
            is_training=False,
+            warmup=(params.batch_idx_train / params.warm_step),
        )
        assert loss.requires_grad is False
        tot_loss = tot_loss + loss_info
@@ -980,6 +986,7 @@ def train_one_epoch(
                sp=sp,
                batch=batch,
                is_training=True,
+                warmup=(params.batch_idx_train / params.warm_step),
            )
            # summary stats
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
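
Note (not part of the patch): a minimal standalone sketch of the thresholding that the new NON_BLANK_THRES constant controls in FrameReducer. A frame is treated as non-blank, and therefore kept, when its CTC blank log-probability falls below log(NON_BLANK_THRES), i.e. P(blank) < 0.9, and it is not padding. The tensor shapes and the inline padding-mask computation below are illustrative stand-ins for the real encoder output and icefall's make_pad_mask.

import math

import torch

NON_BLANK_THRES = 0.9

# Toy dimensions and inputs, for illustration only.
N, T, C, blank_id = 2, 10, 5, 0
ctc_output = torch.randn(N, T, C).log_softmax(dim=-1)  # (N, T, C) CTC log-probs
x_lens = torch.tensor([10, 7])  # valid frames per utterance

# Frames beyond each utterance's length are padding and are never kept
# (equivalent to icefall's make_pad_mask).
padding_mask = torch.arange(T).expand(N, T) >= x_lens.unsqueeze(1)

# Keep a frame when P(blank) < NON_BLANK_THRES and it is not padding.
# The patch uses `*` between the two bool masks; `&` is equivalent here.
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) & (
    ~padding_mask
)
print(non_blank_mask.sum(dim=1))  # number of frames each utterance keeps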