diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..a3dcb6930 --- /dev/null +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao, +# Zhifeng Han,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) ctc-greedy-search (with cr-ctc) +./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 24 \ + --exp-dir ./zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.context_graph import ContextGraph, ContextState +from icefall.decode import ( + ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + DecodingResults, + make_pad_mask, + setup_logger, + store_transcripts_and_timestamps_withoutref, + str2bool, + write_error_stats, + parse_hyp_and_timestamp_ch, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-greedy-search", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + hyps = [] + + if params.decoding_method == "ctc-greedy-search" and params.max_sym_per_frame == 1: + res = ctc_greedy_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + return_timestamps = True, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + hyps, timestamps = parse_hyp_and_timestamp_ch( + res=res, + subsampling_factor=params.subsampling_factor, + word_table = lexicon.token_table, + # frame_shift_ms=params.frame_shift_ms, + ) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "ctc-greedy-search": + return {"ctc-greedy-search_" + key: (hyps, timestamps)} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: (hyps, timestamps)} + else: + return {f"beam_size_{params.beam_size}_" + key: (hyps, timestamps)} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, + with_timestamp: bool = False, +) -> Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + with_timestamp: + Whether to decode with timestamp. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains 4 elements: + Respectively, they are cut_id, the reference transcript, the predicted result and the decoded_timestamps. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "ctc-greedy-search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + if with_timestamp: + + for name, (hyps, timestamps_hyp) in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) and len(timestamps_hyp) == len(hyps) + for cut_id, hyp_words, ref_text, time_hyp in zip( + cut_ids, hyps, texts, timestamps_hyp + ): + this_batch.append((cut_id, ref_text, hyp_words, time_hyp)) + + results[name].extend(this_batch) + else: + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], List[Tuple[float, float]]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts_and_timestamps_withoutref(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + result_without_timestamp = [(res[0], res[1], res[2]) for res in results] + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{key}", + result_without_timestamp, + enable_log=True, + compute_CER=True, + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + # add decoding params + # params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-greedy-search", + ) # only support ctc-greedy-search + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + aishell = AishellAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = aishell.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = aishell.valid_dataloaders(dev_cuts) + + test_cuts = aishell.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + with_timestamp=True, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index dddfe52fa..a3a3473da 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -239,6 +240,27 @@ def add_model_arguments(parser: argparse.ArgumentParser): chunk left-context frames will be chosen randomly from this list; else not relevant.""", ) + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -379,6 +401,27 @@ def get_parser(): with this parameter before adding to the final loss.""", ) + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--seed", type=int, @@ -582,8 +625,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None model = AsrModel( encoder_embed=encoder_embed, @@ -593,9 +641,27 @@ def get_model(params: AttributeDict) -> nn.Module: encoder_dim=int(max(params.encoder_dim.split(","))), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, ) return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment def load_checkpoint_if_available( params: AttributeDict, @@ -722,6 +788,7 @@ def compute_loss( graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -738,6 +805,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ @@ -757,6 +826,21 @@ def compute_loss( y = graph_compiler.texts_to_ids(texts) y = k2.RaggedTensor(y).to(device) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): losses = model( x=feature, @@ -765,25 +849,37 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) - simple_loss, pruned_loss = losses[:2] + simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5] - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss + assert loss.requires_grad == is_training info = MetricsTracker() @@ -793,8 +889,13 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() return loss, info @@ -842,6 +943,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -868,6 +970,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -917,6 +1021,7 @@ def train_one_epoch( graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1082,6 +1187,9 @@ def run(rank, world_size, args): params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + logging.info(params) logging.info("About to create model") @@ -1090,6 +1198,12 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1199,6 +1313,7 @@ def run(rank, world_size, args): optimizer=optimizer, graph_compiler=graph_compiler, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) @@ -1226,6 +1341,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1292,6 +1408,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1309,6 +1426,7 @@ def scan_pessimistic_batches_for_oom( graph_compiler=graph_compiler, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/decode.py b/icefall/decode.py index 5d836bd48..2cdcbf491 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -26,7 +26,7 @@ import torch from icefall.context_graph import ContextGraph, ContextState from icefall.lm_wrapper import LmScorer from icefall.ngram_lm import NgramLm, NgramLmStateCost -from icefall.utils import add_eos, add_sos, get_texts +from icefall.utils import DecodingResults, add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ 0.01, @@ -1485,25 +1485,52 @@ def ctc_greedy_search( ctc_output: torch.Tensor, encoder_out_lens: torch.Tensor, blank_id: int = 0, -) -> List[List[int]]: + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: """CTC greedy search. Args: ctc_output: (batch, seq_len, vocab_size) encoder_out_lens: (batch,) Returns: - List[List[int]]: greedy search result + Union[List[List[int]], DecodingResults]: greedy search result """ batch = ctc_output.shape[0] index = ctc_output.argmax(dim=-1) # (batch, seq_len) - hyps = [ - torch.unique_consecutive(index[i, : encoder_out_lens[i]]) for i in range(batch) - ] - - hyps = [h[h != blank_id].tolist() for h in hyps] - - return hyps - + + hyps = [[] for _ in range(batch)] + # timestamps[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(batch)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(batch)] + + for i in range(batch): + cur_pos = -1 + last = -1 + next_other = torch.where(index[i,0 : encoder_out_lens[i]] != last)[0] + flag = False # whether decoding words + while next_other.size(0) > 0: + cur_pos = cur_pos + 1 + next_other[0].item() + last = index[i,cur_pos].item() + if flag: # last word decoding finished + timestamps[i][-1] = timestamps[i][-1] + (cur_pos - 1, ) + if last != blank_id: + hyps[i].append(last) + timestamps[i].append((cur_pos,)) + scores[i].append(ctc_output[i, cur_pos, last].item()) + flag = True # Decoding words + else: + flag = False + + next_other = torch.where(index[i,cur_pos + 1 : encoder_out_lens[i]] != last)[0] + if flag: + timestamps[i][-1] = timestamps[i][-1] + (encoder_out_lens[i].item() - 1, ) + + if not return_timestamps: + return hyps + else: + return DecodingResults(hyps = hyps, timestamps = timestamps, scores = scores) @dataclass class Hypothesis: diff --git a/icefall/utils.py b/icefall/utils.py index ffb926566..5510f0476 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -360,7 +360,7 @@ class KeywordResult: class DecodingResults: # timestamps[i][k] contains the frame number on which tokens[i][k] # is decoded - timestamps: List[List[int]] + timestamps: List[List[Union[int, Tuple[int, int]]]] # hyps[i] is the recognition results, i.e., word IDs or token IDs # for the i-th utterance with fast_beam_search_nbest_LG. @@ -583,6 +583,40 @@ def store_transcripts_and_timestamps( s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) +def store_transcripts_and_timestamps_withoutref( + filename: Pathlike, + texts: Iterable[Tuple[str, List[str], List[str], List[Tuple[float]]]], +) -> None: + """Save predicted results and reference transcripts as well as their timestamps + to a file. + + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf8") as f: + for cut_id, ref, hyp, time_hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + if len(time_hyp) > 0: + if isinstance(time_hyp[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + def write_error_stats( f: TextIO, @@ -1839,6 +1873,30 @@ def convert_timestamp( return time +def convert_timestamp_duration( + frames: List[Tuple[int, int]], + subsampling_factor: int, + frame_shift_ms: float = 10, +) -> List[Tuple[float, float]]: + """Convert frame numbers to time (in seconds) given subsampling factor + and frame shift (in milliseconds). + + Args: + frames: + A list of frame numbers after subsampling. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. + Return: + Return the time in seconds corresponding to each given frame. + """ + frame_shift = frame_shift_ms / 1000.0 + time = [] + for f_start, f_end in frames: + time.append((round(f_start * subsampling_factor * frame_shift, ndigits=3), round(f_end * subsampling_factor * frame_shift, ndigits=3))) + + return time def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: """ @@ -1924,6 +1982,43 @@ def parse_hyp_and_timestamp( return hyps, timestamps +def parse_hyp_and_timestamp_ch( + res: DecodingResults, + subsampling_factor: int, + word_table: k2.SymbolTable, + frame_shift_ms: float = 10, +) -> Tuple[List[List[str]], List[List[float]]]: + """Parse hypothesis and timestamps of Chinese characters. + + Args: + res: + A DecodingResults object. + subsampling_factor: + The integer subsampling factor. + word_table: + The word symbol table. + frame_shift_ms: + The float frame shift used for feature extraction. + + Returns: + Return a list of hypothesis and timestamp. + """ + hyps = [] + timestamps = [] + + N = len(res.hyps) + assert len(res.timestamps) == N, (len(res.timestamps), N) + assert word_table is not None + + for i in range(N): + time = convert_timestamp_duration(res.timestamps[i], subsampling_factor, frame_shift_ms) + words_decoded = [word_table[idx] for idx in res.hyps[i]] + hyps.append(words_decoded) + timestamps.append(time) + assert len(time) == len(words_decoded), (len(time), len(words_decoded)) + + return hyps, timestamps + # `is_module_available` is copied from # https://github.com/pytorch/audio/blob/6bad3a66a7a1c7cc05755e9ee5931b7391d2b94c/torchaudio/_internal/module_utils.py#L9