From 9293edc62f4a3ebf769d66cc037d4e67953440f5 Mon Sep 17 00:00:00 2001 From: Mistmoon <2070620600@qq.com> Date: Tue, 8 Jul 2025 14:47:24 +0800 Subject: [PATCH] Add cr-ctc loss and ctc-decode in aishell (#1980) --- egs/aishell/ASR/RESULTS.md | 58 +++ egs/aishell/ASR/zipformer/ctc_decode.py | 540 ++++++++++++++++++++++++ egs/aishell/ASR/zipformer/train.py | 159 ++++++- 3 files changed, 738 insertions(+), 19 deletions(-) create mode 100755 egs/aishell/ASR/zipformer/ctc_decode.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 355d1516d..38a94d6c6 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,5 +1,63 @@ ## Results +### Aishell training results (zipformer + CR-CTC) + +See <https://github.com/k2-fsa/icefall/pull/1980> for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### medium-scale model, number of model parameters: 66218471, i.e., 66.2 M + +| decoding method | test | dev | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 | +| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 | + +The training command using 2 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 60 \ + --start-epoch 1 \ + --use-fp16 1 \ + --context-size 1 \ + --enable-musan 0 \ + --exp-dir zipformer/exp \ + --max-duration 500 \ + --base-lr 0.045 \ + --lr-batches 7500 \ + --lr-epochs 18 \ + --spec-aug-time-warp-factor 20 \ + --use-ctc 1 \ + --use-cr-ctc 1 \ + --use-transducer 0 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search ctc-prefix-beam-search; do + ./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 28 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +Pretrained models, training logs, 
decoding logs, tensorboard and decoding results +are available at + + ### Aishell training results (Fine-tuning Pretrained Models) #### Whisper [./whisper](./whisper) diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..940aef1e5 --- /dev/null +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao, +# Zhifeng Han,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-greedy-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 28 \ + --exp-dir ./zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search +(2) ctc-prefix-beam-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 60 \ + --avg 21 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-prefix-beam-search +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + (2) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "beam": 4, # for prefix-beam-search + } + ) + return params + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + batch: dict, +) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. 
`len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + hyp_tokens = [] + hyps = [] + + if params.decoding_method == "ctc-greedy-search": + hyp_tokens = ctc_greedy_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "ctc-prefix-beam-search": + hyp_tokens = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + + if params.decoding_method == "ctc-greedy-search": + return {"ctc-greedy-search" : hyps} + elif 
params.decoding_method == "ctc-prefix-beam-search": + return {"ctc-prefix-beam-search" : hyps} + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains 3 elements: + Respectively, they are cut_id, the reference transcript, and the predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + batch=batch, + ) + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / 
f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results, char_level = True) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + compute_CER=True, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-prefix-beam-search", + ) # support ctc-greedy-search and ctc-prefix-beam-search + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." 
+ assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = lexicon.token_table["<blk>"] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if 
len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + aishell = AishellAsrDataModule(args) + + dev_cuts = aishell.valid_cuts() + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 0c389db55..3104665b0 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -240,6 +241,27 @@ def add_model_arguments(parser: argparse.ArgumentParser): chunk left-context frames will be chosen randomly from this list; else not relevant.""", ) + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -380,6 +402,27 @@ def get_parser(): with this parameter before adding to the final loss.""", ) + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + 
help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--seed", type=int, @@ -583,8 +626,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None model = AsrModel( encoder_embed=encoder_embed, @@ -594,9 +642,27 @@ def get_model(params: AttributeDict) -> nn.Module: encoder_dim=int(max(params.encoder_dim.split(","))), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, ) return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment def load_checkpoint_if_available( params: AttributeDict, @@ -723,6 +789,7 @@ def compute_loss( graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -739,6 +806,8 @@ def compute_loss( True for training. False for validation. 
When it is True, this function enables autograd during computation; when it is False, it disables autograd. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ @@ -758,6 +827,21 @@ def compute_loss( y = graph_compiler.texts_to_ids(texts) y = k2.RaggedTensor(y).to(device) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): losses = model( x=feature, @@ -766,25 +850,40 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) - simple_loss, pruned_loss = losses[:2] + if params.use_ctc: + simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5] + else: + simple_loss, pruned_loss = losses[:2] - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + assert loss.requires_grad == is_training info = MetricsTracker() @@ -794,8 +893,13 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() return loss, info @@ -843,6 +947,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -869,6 +974,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. 
tb_writer: @@ -918,6 +1025,7 @@ def train_one_epoch( graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1083,6 +1191,9 @@ def run(rank, world_size, args): params.blank_id = lexicon.token_table["<blk>"] params.vocab_size = max(lexicon.tokens) + 1 + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + logging.info(params) logging.info("About to create model") @@ -1091,6 +1202,12 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1200,6 +1317,7 @@ def run(rank, world_size, args): optimizer=optimizer, graph_compiler=graph_compiler, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) @@ -1227,6 +1345,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1293,6 +1412,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1310,6 +1430,7 @@ def scan_pessimistic_batches_for_oom( graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad()