diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 20a8f7b3a..ee2f31483 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -17,7 +17,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import argparse
 import logging
 from collections import defaultdict
@@ -42,6 +41,7 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    get_env_info,
     get_texts,
     setup_logger,
     store_transcripts,
@@ -100,7 +100,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
@@ -122,15 +122,35 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The LM dir.
+        It should contain either G_3_gram.pt or G_3_gram.fst.txt
+        """,
+    )
+
     return parser


 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_char"),
-            "lm_dir": Path("data/lm"),
             # parameters for conformer
             "subsampling_factor": 4,
             "feature_dim": 80,
@@ -146,6 +166,7 @@ def get_params() -> AttributeDict:
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "env_info": get_env_info(),
         }
     )
     return params
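Editor's note on the `--lattice-score-scale` to `--nbest-scale` rename above: the value scales `lattice.scores` before paths are sampled for n-best processing, so values below 1 flatten the arc posteriors and produce more diverse n-best lists. A minimal sketch of the idea, with a hypothetical helper name (the real logic lives in icefall's n-best utilities):

```python
import k2


def sample_nbest_paths(lattice: k2.Fsa, num_paths: int, nbest_scale: float):
    """Sample `num_paths` paths from `lattice`, scaling scores first.

    Scaling scores by nbest_scale < 1 flattens the distribution over
    paths, so k2.random_paths() returns more diverse candidates instead
    of near-duplicates of the single best path.
    """
    saved_scores = lattice.scores.clone()
    lattice.scores *= nbest_scale
    # A ragged tensor of arc indexes, one sub-list per sampled path.
    paths = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )
    lattice.scores = saved_scores  # restore the original scores
    return paths
```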
""" - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervision_segments = torch.stack( ( @@ -218,9 +246,16 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -229,18 +264,37 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + key = "ctc-decoding" + hyps = [[lexicon.token_table[i] for i in ids] for ids in token_ids] + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons # as HLG decoding is faster and the oracle WER - # is slightly worse than that of rescored lattices. - return nbest_oracle( + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( lattice=lattice, num_paths=params.num_paths, ref_texts=supervisions["text"], - word_table=word_table, - scale=params.lattice_score_scale, + word_table=lexicon.word_table, + nbest_scale=params.nbest_scale, + oov="", ) + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} if params.method in ["1best", "nbest"]: if params.method == "1best": @@ -253,12 +307,12 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] return {key: hyps} assert params.method == "attention-decoder" @@ -271,13 +325,14 @@ def decode_one_batch( memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, eos_id=eos_id, - scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps return ans @@ -285,8 +340,9 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, - word_table: k2.SymbolTable, + HLG: 
@@ -285,8 +340,9 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
-    word_table: k2.SymbolTable,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    lexicon: Lexicon,
     sos_id: int,
     eos_id: int,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
@@ -300,9 +356,11 @@
       model:
         The neural model.
       HLG:
-        The decoding graph.
-      word_table:
-        It is the word symbol table.
+        The decoding graph. Used when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      lexicon:
+        It contains the token symbol table and the word symbol table.
       sos_id:
         The token ID for SOS.
       eos_id:
@@ -331,14 +389,16 @@
             params=params,
             model=model,
             HLG=HLG,
+            H=H,
             batch=batch,
-            word_table=word_table,
+            lexicon=lexicon,
             sos_id=sos_id,
             eos_id=eos_id,
         )
 
         for lm_scale, hyps in hyps_dict.items():
             this_batch = []
+            assert len(hyps) == len(texts)
             for hyp_words, ref_text in zip(hyps, texts):
                 ref_words = ref_text.split()
                 this_batch.append((ref_words, hyp_words))
@@ -411,6 +471,9 @@ def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -438,14 +501,22 @@
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
-    HLG = HLG.to(device)
-    assert HLG.requires_grad is False
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+    else:
+        H = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
+        )
+        assert HLG.requires_grad is False
 
-    if not hasattr(HLG, "lm_scores"):
-        HLG.lm_scores = HLG.scores.clone()
+        if not hasattr(HLG, "lm_scores"):
+            HLG.lm_scores = HLG.scores.clone()
 
     model = Conformer(
         num_features=params.feature_dim,
@@ -468,7 +539,8 @@
             if start >= 0:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     if params.export:
         logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
@@ -483,12 +555,7 @@
     logging.info(f"Number of model parameters: {num_param}")
 
     aishell = AishellAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-
+
     test_sets = ["test"]
     for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
         results_dict = decode_dataset(
@@ -496,7 +563,8 @@
             params=params,
             model=model,
             HLG=HLG,
-            word_table=lexicon.word_table,
+            H=H,
+            lexicon=lexicon,
             sos_id=sos_id,
             eos_id=eos_id,
         )
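One behavioral note on the decode.py changes before the train.py half: `average_checkpoints` is now given the target device, and the model is moved to `device` before the averaged state dict is loaded, so averaging no longer round-trips through the CPU. Conceptually the helper is just an element-wise mean over the saved `model` state dicts; a rough sketch under that assumption (not icefall's exact implementation):

```python
from typing import List

import torch


def average_checkpoints_sketch(
    filenames: List[str], device: torch.device
) -> dict:
    """Return the element-wise average of the saved model state dicts."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location=device)["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location=device)["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n  # integer buffers are floor-divided
    return avg
```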
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index 3c54fc42a..94367ed4e 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -16,16 +16,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Optional
+from typing import Optional, Tuple
 
 import k2
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
@@ -43,7 +41,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    MetricsTracker,
     encode_supervisions,
+    get_env_info,
     setup_logger,
     str2bool,
 )
@@ -78,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=50,
+        default=90,
         help="Number of epochs to train.",
     )
@@ -92,6 +92,35 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.7,
+        help="""The attention rate.
+        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
     return parser
@@ -99,19 +128,13 @@ def get_params() -> AttributeDict:
     """Return a dict containing training parameters.
 
     All training related parameters that are not passed from the commandline
-    is saved in the variable `params`.
+    are saved in the variable `params`.
 
     Commandline options are merged into `params` after they are parsed, so
     you can also access them via `params`.
 
     Explanation of options saved in `params`:
 
-        - exp_dir: It specifies the directory where all training related
-                   files, e.g., checkpoints, log, etc, are saved
-
-        - lang_dir: It contains language related input files such as
-                    "lexicon.txt"
-
         - best_valid_loss: Best validation loss so far. It is used to
           select the model that has the lowest validation loss. It is
           updated during the training.
@@ -136,9 +159,6 @@ def get_params() -> AttributeDict:
 
         - use_double_scores: It is used in k2.ctc_loss
 
-        - att_rate: The proportion of label smoothing loss, final loss will be
-          (1 - att_rate) * ctc_loss + att_rate * label_smoothing_loss
-
         - subsampling_factor: The subsampling factor for the model.
 
         - feature_dim: The model input dim. It has to match the one used
@@ -163,8 +183,6 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_char"),
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -177,7 +195,6 @@ def get_params() -> AttributeDict:
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
-            "att_rate": 0.7,
             # parameters for conformer
             "subsampling_factor": 4,
             "feature_dim": 80,
@@ -190,6 +207,7 @@ def get_params() -> AttributeDict:
             "weight_decay": 1e-5,
             "lr_factor": 5.0,
             "warm_step": 36000,
+            "env_info": get_env_info(),
         }
     )
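With `att_rate` promoted from a hard-coded parameter to `--att-rate`, the combined objective keeps the form given in its help string: `(1 - att_rate) * ctc_loss + att_rate * att_loss`. A quick numeric check with the default rate (made-up loss sums, since `reduction="sum"` produces per-batch totals):

```python
att_rate = 0.7
ctc_loss, att_loss = 300.0, 200.0  # hypothetical per-batch sums
loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
assert abs(loss - 230.0) < 1e-9  # 0.3 * 300 + 0.7 * 200
```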
@@ -289,7 +307,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: CharCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[torch.Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
 
@@ -312,14 +330,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
 
     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)
 
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -348,36 +366,41 @@ def compute_loss(
 
     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
+            mmodel = model.module if hasattr(model, "module") else model
+            # Note: We need to generate an unsorted version of token_ids
+            # `encode_supervisions()` called above sorts text, but
+            # encoder_memory and memory_mask are not sorted, so we
+            # use an unsorted version `supervisions["text"]` to regenerate
+            # the token_ids
+            #
+            # See https://github.com/k2-fsa/icefall/issues/97
+            # for more details
+            unsorted_token_ids = graph_compiler.texts_to_ids(
+                supervisions["text"]
+            )
+            att_loss = mmodel.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=unsorted_token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
         att_loss = torch.tensor([0])
 
-    # train_frames and valid_frames are used for printing.
-    if is_training:
-        params.train_frames = supervision_segments[:, 2].sum().item()
-    else:
-        params.valid_frames = supervision_segments[:, 2].sum().item()
-
     assert loss.requires_grad == is_training
-    return loss, ctc_loss.detach(), att_loss.detach()
+
+    info = MetricsTracker()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.att_rate != 0.0:
+        info["att_loss"] = att_loss.detach().cpu().item()
+
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
 
 
 def compute_validation_loss(
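`compute_loss` now returns a `MetricsTracker` instead of a tuple of tensors. In icefall this is essentially a `defaultdict(int)` with arithmetic defined on it, which is what makes `tot_loss + loss_info`, `tot_loss * scale`, and the `loss[{loss_info}]` f-string in the new logging code work. A stripped-down sketch of the interface relied on here (the real class also implements `reduce()` and `write_summary()`):

```python
from collections import defaultdict


class MetricsTrackerSketch(defaultdict):
    """Accumulates named sums (frames, ctc_loss, ...) supporting + and *."""

    def __init__(self):
        super().__init__(int)  # missing keys default to 0

    def __add__(self, other):
        ans = MetricsTrackerSketch()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float):
        ans = MetricsTrackerSketch()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self):
        # Report each metric normalized by the number of frames.
        frames = self["frames"] if self["frames"] != 0 else 1
        return ", ".join(
            f"{k}={v / frames:.4g}" for k, v in self.items() if k != "frames"
        )
```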
""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -405,36 +426,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -473,18 +475,13 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = MetricsTracker() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -492,6 +489,9 @@ def train_one_epoch( is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
@@ -500,75 +500,26 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
-
         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )
 
+        if batch_idx % params.log_interval == 0:
+
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                loss_info.write_summary(
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
                 )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-
-            tot_frames = 0.0  # sum of frames over all batches
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -576,33 +527,14 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    params.train_loss = params.tot_loss / params.tot_frames
-
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
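In the multi-GPU case, the `tot_loss.reduce(loss.device)` call added to `compute_validation_loss` sums every tracked quantity across ranks before the normalized loss is computed, replacing the hand-rolled `dist.all_reduce` on a packed tensor that the old code used. A sketch of what such a reduce amounts to (assuming a dict-like tracker as above and an initialized process group):

```python
import torch
import torch.distributed as dist


def reduce_sketch(tracker: dict, device: torch.device) -> None:
    """All-reduce (sum) every tracked value across workers, in place."""
    keys = sorted(tracker.keys())  # same iteration order on every rank
    s = torch.tensor([float(tracker[k]) for k in keys], device=device)
    dist.all_reduce(s, op=dist.ReduceOp.SUM)
    for k, v in zip(keys, s.cpu().tolist()):
        tracker[k] = v
```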
@@ -729,7 +661,8 @@ def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
 
     world_size = args.world_size
     assert world_size >= 1
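The trailing `world_size` check feeds the usual icefall launch pattern, which this diff leaves unchanged: one training process per GPU via `torch.multiprocessing.spawn`. Roughly, and assuming a `run(rank, world_size, args)` entry point as in other icefall recipes:

```python
import torch.multiprocessing as mp

if world_size > 1:
    # Executes run(rank, world_size, args) once per GPU.
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
    run(rank=0, world_size=1, args=args)
```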