From 0cad3362771ee5f85a985d617636e257043809b8 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Tue, 13 Jun 2023 07:59:05 -0400 Subject: [PATCH] remove unwanted files --- egs/librispeech/ASR/zipformer_ctc/__init__.py | 0 .../ASR/zipformer_ctc/asr_datamodule.py | 1 - egs/librispeech/ASR/zipformer_ctc/decode.py | 886 ------------- egs/librispeech/ASR/zipformer_ctc/decoder.py | 298 ----- .../ASR/zipformer_ctc/encoder_interface.py | 1 - egs/librispeech/ASR/zipformer_ctc/export.py | 240 ---- .../ASR/zipformer_ctc/label_smoothing.py | 1 - egs/librispeech/ASR/zipformer_ctc/model.py | 158 --- egs/librispeech/ASR/zipformer_ctc/optim.py | 1 - egs/librispeech/ASR/zipformer_ctc/scaling.py | 1 - .../ASR/zipformer_ctc/scaling_converter.py | 1 - .../ASR/zipformer_ctc/subsampling.py | 1 - egs/librispeech/ASR/zipformer_ctc/train.py | 1135 ----------------- .../ASR/zipformer_ctc/transformer.py | 1 - .../ASR/zipformer_ctc/zipformer.py | 1 - 15 files changed, 2726 deletions(-) delete mode 100644 egs/librispeech/ASR/zipformer_ctc/__init__.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py delete mode 100755 egs/librispeech/ASR/zipformer_ctc/decode.py delete mode 100644 egs/librispeech/ASR/zipformer_ctc/decoder.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/encoder_interface.py delete mode 100755 egs/librispeech/ASR/zipformer_ctc/export.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/label_smoothing.py delete mode 100644 egs/librispeech/ASR/zipformer_ctc/model.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/optim.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/scaling_converter.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/subsampling.py delete mode 100755 egs/librispeech/ASR/zipformer_ctc/train.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/transformer.py delete mode 120000 egs/librispeech/ASR/zipformer_ctc/zipformer.py diff --git a/egs/librispeech/ASR/zipformer_ctc/__init__.py b/egs/librispeech/ASR/zipformer_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py deleted file mode 120000 index fa1b8cca3..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/decode.py b/egs/librispeech/ASR/zipformer_ctc/decode.py deleted file mode 100755 index 7f605e2c8..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/decode.py +++ /dev/null @@ -1,886 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_ctc_model, get_params -from transformer import encoder_padding_mask - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_n_best_list, - rescore_with_rnn_lm, - rescore_with_whole_lattice, -) -from icefall.lexicon import Lexicon -from icefall.rnn_lm.model import RnnLmModel -from icefall.utils import ( - AttributeDict, - get_texts, - load_averaged_model, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=77, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=55, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--method", - type=str, - default="attention-decoder", - help="""Decoding method. - Supported values are: - - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece - model, i.e., lang_dir/bpe.model, to convert word pieces to words. - It needs neither a lexicon nor an n-gram LM. - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - - (5) attention-decoder. Extract n paths from the LM rescored - lattice, the path with the highest score is the decoding result. - - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume - you have trained an RNN LM using ./rnn_lm/train.py - - (7) nbest-oracle. Its WER is the lower bound of any n-best - rescoring method can achieve. Useful for debugging n-best - rescoring method. 
- """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--lm-dir", - type=str, - default="data/lm", - help="""The n-gram LM dir. - It should contain either G_4_gram.pt or G_4_gram.fst.txt - """, - ) - - parser.add_argument( - "--rnn-lm-exp-dir", - type=str, - default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. - It specifies the path to RNN LM exp dir. - """, - ) - - parser.add_argument( - "--rnn-lm-epoch", - type=int, - default=7, - help="""Used only when --method is rnn-lm. - It specifies the checkpoint to use. - """, - ) - - parser.add_argument( - "--rnn-lm-avg", - type=int, - default=2, - help="""Used only when --method is rnn-lm. - It specifies the number of checkpoints to average. - """, - ) - - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=4, - help="Number of RNN layers the model", - ) - parser.add_argument( - "--rnn-lm-tie-weights", - type=str2bool, - default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - rnn_lm_model: Optional[nn.Module], - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - rnn_lm_model: - The neural model for RNN LM. - HLG: - The decoding graph. 
Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. Note: If it decodes to nothing, then return None. - """ - if HLG is not None: - device = HLG.device - else: - device = H.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - nnet_output, _ = model.encoder(feature, feature_lens) - ctc_output = model.ctc_output(nnet_output) - # nnet_output is (N, T, C) - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - if H is None: - assert HLG is not None - decoding_graph = HLG - else: - assert HLG is None - assert bpe_model is not None - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "ctc-decoding": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] - key = "ctc-decoding" - return {key: hyps} - - if params.method == "nbest-oracle": - # Note: You can also pass rescored lattices to it. - # We choose the HLG decoded lattice for speed reasons - # as HLG decoding is faster and the oracle WER - # is only slightly worse than that of rescored lattices. 
- best_path = nbest_oracle( - lattice=lattice, - num_paths=params.num_paths, - ref_texts=supervisions["text"], - word_table=word_table, - nbest_scale=params.nbest_scale, - oov="", - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa - return {key: hyps} - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa - - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - "rnn-lm", - ] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - - nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(nnet_output.size(0), supervisions) - mask = mask.to(nnet_output.device) if mask is not None else None - mmodel = model.decoder.module if hasattr(model.decoder, "module") else model.decoder - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - elif params.method == "whole-lattice-rescoring": - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - elif params.method == "attention-decoder": - # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=mmodel, - memory=nnet_output, - memory_key_padding_mask=mask, - sos_id=sos_id, - eos_id=eos_id, - nbest_scale=params.nbest_scale, - ) - elif params.method == "rnn-lm": - # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=None, - ) - - best_path_dict = rescore_with_rnn_lm( - lattice=rescored_lattice, - num_paths=params.num_paths, - rnn_lm_model=rnn_lm_model, - model=mmodel, - memory=nnet_output, - memory_key_padding_mask=mask, - sos_id=sos_id, - eos_id=eos_id, - blank_id=0, - nbest_scale=params.nbest_scale, - ) - else: - assert False, f"Unsupported decoding method: {params.method}" - - ans = dict() - if best_path_dict is not None: - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - else: - ans = None - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - rnn_lm_model: Optional[nn.Module], - HLG: Optional[k2.Fsa], - H: Optional[k2.Fsa], - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. 
- - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - rnn_lm_model: - The neural model for RNN LM. - HLG: - The decoding graph. Used only when params.method is NOT ctc-decoding. - H: - The ctc topo. Used only when params.method is ctc-decoding. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - rnn_lm_model=rnn_lm_model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - if hyps_dict is not None: - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - else: - assert len(results) > 0, "It should not decode to empty in the first batch!" - this_batch = [] - hyp_words = [] - for ref_text in texts: - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - for lm_scale in results.keys(): - results[lm_scale].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], -): - if params.method in ("attention-decoder", "rnn-lm"): - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - args.lm_dir = Path(args.lm_dir) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - params.num_classes = num_classes - params.sos_id = sos_id - params.eos_id = eos_id - - if params.method == "ctc-decoding": - HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - else: - H = None - bpe_model = None - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) - ) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ( - "nbest-rescoring", - "whole-lattice-rescoring", - "attention-decoder", - "rnn-lm", - ): - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - # Save a dummy value so that it can be loaded in C++. - # See https://github.com/pytorch/pytorch/issues/67902 - # for why we need to do this. 
- G.dummy = 1 - - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) - G = k2.Fsa.from_dict(d) - - if params.method in [ - "whole-lattice-rescoring", - "attention-decoder", - "rnn-lm", - ]: - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - logging.info("About to create model") - model = get_ctc_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - rnn_lm_model = None - if params.method == "rnn-lm": - rnn_lm_model = RnnLmModel( - vocab_size=params.num_classes, - embedding_dim=params.rnn_lm_embedding_dim, - hidden_dim=params.rnn_lm_hidden_dim, - num_layers=params.rnn_lm_num_layers, - tie_weights=params.rnn_lm_tie_weights, - ) - if params.rnn_lm_avg == 1: - load_checkpoint( - f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", - rnn_lm_model, - ) - 
rnn_lm_model.to(device) - else: - rnn_lm_model = load_averaged_model( - params.rnn_lm_exp_dir, - rnn_lm_model, - params.rnn_lm_epoch, - params.rnn_lm_avg, - device, - ) - rnn_lm_model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - rnn_lm_model=rnn_lm_model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zipformer_ctc/decoder.py b/egs/librispeech/ASR/zipformer_ctc/decoder.py deleted file mode 100644 index 8dec048a1..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/decoder.py +++ /dev/null @@ -1,298 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import torch -import torch.nn as nn -import torch.nn.functional as F -from label_smoothing import LabelSmoothingLoss -from torch.nn.utils.rnn import pad_sequence -from transformer import PositionalEncoding, TransformerDecoderLayer - - -class Decoder(nn.Module): - """This class implements Transformer based decoder for an attention-based encoder-decoder - model. - """ - - def __init__( - self, - num_layers: int, - num_classes: int, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - dropout: float = 0.1, - normalize_before: bool = True, - ): - """ - Args: - num_layers: - Number of layers. - num_classes: - Number of tokens of the modeling unit including blank. - d_model: - Dimension of the input embedding, and of the decoder output. 
- """ - super().__init__() - - if num_layers > 0: - self.decoder_num_class = num_classes # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - - @torch.jit.export - def forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). 
- """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - return [utt + [eos_id] for utt in token_ids] - - -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: - """Generate a length mask for input. - The masked position are filled with True, - Unmasked positions are filled with False. - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). 
- The mask can be used for masked self-attention. - For instance, if sz is 3, it returns:: - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - Args: - sz: mask size - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py b/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py deleted file mode 120000 index b8529e0b7..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py deleted file mode 100755 index 0ff50f128..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/export.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -import argparse -import logging -from pathlib import Path - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_ctc_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer_ctc/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - add_model_arguments(parser) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = get_ctc_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - # We won't use the forward() method of the model in C++, 
so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - convert_scaled_to_non_scaled(model, inplace=True) - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py b/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py deleted file mode 120000 index 08734abd7..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py deleted file mode 100644 index 2aeb8a072..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/model.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from transformer import encoder_padding_mask - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.utils import encode_supervisions - - -class CTCModel(nn.Module): - """It implements a CTC model with an auxiliary attention head.""" - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - encoder_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - An instance of `EncoderInterface`. The shared encoder for the CTC and attention - branches - decoder: - An instance of `nn.Module`. This is the decoder for the attention branch. - encoder_dim: - Dimension of the encoder output. - decoder_dim: - Dimension of the decoder output. - vocab_size: - Number of tokens of the modeling unit including blank. 
- """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - - self.encoder = encoder - self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), - nn.Linear(encoder_dim, vocab_size), - nn.LogSoftmax(dim=-1), - ) - self.decoder = decoder - - @torch.jit.ignore - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - supervisions: torch.Tensor, - graph_compiler: BpeCtcTrainingGraphCompiler, - subsampling_factor: int = 1, - beam_size: int = 10, - reduction: str = "sum", - use_double_scores: bool = False, - ) -> torch.Tensor: - """ - Args: - x: - Tensor of dimension (N, T, C) where N is the batch size, - T is the number of frames, and C is the feature dimension. - x_lens: - Tensor of dimension (N,) where N is the batch size. - supervisions: - Supervisions are used in training. - graph_compiler: - It is used to compile a decoding graph from texts. - subsampling_factor: - It is used to compute the `supervisions` for the encoder. - beam_size: - Beam size used in `k2.ctc_loss`. - reduction: - Reduction method used in `k2.ctc_loss`. - use_double_scores: - If True, use double precision in `k2.ctc_loss`. - Returns: - Return the CTC loss, attention loss, and the total number of frames. - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - - nnet_output, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - # compute ctc log-probs - ctc_output = self.ctc_output(nnet_output) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=subsampling_factor - ) - num_frames = supervision_segments[:, 2].sum().item() - - # Works with a BPE model - token_ids = graph_compiler.texts_to_ids(texts) - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - ctc_output, - supervision_segments.cpu(), - allow_truncate=subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=beam_size, - reduction=reduction, - use_double_scores=use_double_scores, - ) - - if self.decoder is not None: - nnet_output = nnet_output.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mmodel = ( - self.decoder.module if hasattr(self.decoder, "module") else self.decoder - ) - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) - mask = encoder_padding_mask(nnet_output.size(0), supervisions) - mask = mask.to(nnet_output.device) if mask is not None else None - att_loss = mmodel.forward( - nnet_output, - mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - else: - att_loss = torch.tensor([0]) - - return ctc_loss, att_loss, num_frames diff --git a/egs/librispeech/ASR/zipformer_ctc/optim.py b/egs/librispeech/ASR/zipformer_ctc/optim.py deleted file mode 120000 index 81ac4a89a..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/optim.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/zipformer_ctc/scaling.py b/egs/librispeech/ASR/zipformer_ctc/scaling.py deleted file mode 120000 index 2428b74b9..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py b/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py deleted file mode 120000 index b8b8ba432..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/subsampling.py b/egs/librispeech/ASR/zipformer_ctc/subsampling.py deleted file mode 120000 index 6fee09e58..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py deleted file mode 100755 index f40344357..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ /dev/null @@ -1,1135 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Usage: - export CUDA_VISIBLE_DEVICES="0,1,2,3" - ./zipformer_ctc/train.py \ - --exp-dir ./zipformer_ctc/exp \ - --world-size 4 \ - --full-libri 1 \ - --max-duration 500 \ - --num-epochs 30 -""" - -import argparse -import copy -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decoder import Decoder -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import CTCModel -from optim import Eden, LRScheduler, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--num-decoder-layers", - type=int, - default=6, - help="""Number of decoder layer of transformer decoder. - Setting this to 0 will not create the decoder at all (pure CTC model) - """, - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.8, - help="""The attention rate. - The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. 
The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Normalization for the input features, can be a - boolean indicating whether to do batch - normalization, or a float which means just scaling - the input features with this float value. - If given a float value, we will remove batchnorm - layer in `ConvolutionModule` as well. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - warm_step: The warm_step for Noam optimizer. 
- """ - params = AttributeDict( - { - "frame_shift_ms": 10.0, - "allowed_excess_duration_ratio": 0.1, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - num_layers=params.num_decoder_layers, - num_classes=params.vocab_size, - d_model=int(params.encoder_dims.split(",")[-1]), - ) - return decoder - - -def get_ctc_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - - model = CTCModel( - encoder=encoder, - decoder=decoder, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" 
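-    # `load_checkpoint` restores the states of `model` (and of `model_avg`,
-    # `optimizer` and `scheduler` when they are given) and returns the loaded
-    # checkpoint dict; the best-loss bookkeeping below is then copied from it
-    # into `params`.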
- - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. 
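-    Returns:
-      Return a tuple of two elements: the loss tensor and a `MetricsTracker`
-      containing per-batch statistics such as the number of frames and the
-      CTC, attention and total loss values.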
- """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - with torch.set_grad_enabled(is_training): - ctc_loss, att_loss, tot_frames = model( - feature, - feature_lens, - supervisions, - graph_compiler, - subsampling_factor=params.subsampling_factor, - beam_size=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - info = MetricsTracker() - info["frames"] = tot_frames - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - assert loss.requires_grad == is_training, f"{loss.requires_grad} != {is_training}" - info["loss"] = loss.detach().cpu().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = supervisions["num_frames"].sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() - ) - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. 
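-      scheduler:
-        The learning rate scheduler; `scheduler.step_batch()` is called once
-        per training batch.
-      scaler:
-        The GradScaler used for mixed precision training when --use-fp16 is
-        given.
-      model_avg:
-        The averaged model; it is updated every `params.average_period`
-        batches on rank 0.
-      rank:
-        The rank of this node in DDP training; only rank 0 saves checkpoints.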
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
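-            # Concretely: the scale is doubled whenever it is below 1.0 (this
-            # check runs every 100 batches) or below 8.0 (only every 400
-            # batches); a warning is logged below 0.01, and training aborts if
-            # it collapses below 1e-5, which usually indicates inf/nan grads.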
- cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - logging.info("About to create model") - - model = get_ctc_model(params) - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 2**22 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - librispeech = LibriSpeechAsrDataModule(args) - - if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() - else: - train_cuts = librispeech.train_clean_100_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. 
Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - return 1.0 <= c.duration <= 25.0 - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - train_dl = librispeech.train_dataloaders(train_cuts) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: BpeCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." 
- ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/zipformer_ctc/transformer.py b/egs/librispeech/ASR/zipformer_ctc/transformer.py deleted file mode 120000 index 4c890cf29..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/transformer.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer_ctc/zipformer.py b/egs/librispeech/ASR/zipformer_ctc/zipformer.py deleted file mode 120000 index 79b076556..000000000 --- a/egs/librispeech/ASR/zipformer_ctc/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../pruned_transducer_stateless7/zipformer.py \ No newline at end of file
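For reference, the removed train.py combines the CTC and attention losses as
(1 - att_rate) * ctc_loss + att_rate * att_loss and keeps a running average of
the model weights using the formula quoted in the --average-period help string.
Below is a minimal, self-contained sketch of those two updates. The helper
names (combine_losses, average_model_params) and the toy tensors are
illustrative only; they are not icefall APIs and this is not the recipe's
actual implementation.

import copy

import torch
import torch.nn as nn


def combine_losses(
    ctc_loss: torch.Tensor, att_loss: torch.Tensor, att_rate: float = 0.8
) -> torch.Tensor:
    # Total loss as described in the --att-rate help string of the removed recipe.
    return (1.0 - att_rate) * ctc_loss + att_rate * att_loss


def average_model_params(
    model: nn.Module, model_avg: nn.Module, average_period: int, batch_idx_train: int
) -> None:
    # model_avg = model * (average_period / batch_idx_train)
    #           + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
    w_cur = average_period / batch_idx_train
    with torch.no_grad():
        for p_avg, p_cur in zip(model_avg.parameters(), model.parameters()):
            p_avg.mul_(1.0 - w_cur).add_(p_cur.to(p_avg.dtype), alpha=w_cur)


if __name__ == "__main__":
    model = nn.Linear(4, 4)
    # The averaged model is kept in float64, mirroring the removed run() code.
    model_avg = copy.deepcopy(model).to(torch.float64)
    loss = combine_losses(torch.tensor(2.0), torch.tensor(1.0))
    average_model_params(model, model_avg, average_period=200, batch_idx_train=400)
    print(loss.item())  # 0.2 * 2.0 + 0.8 * 1.0 = 1.2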