diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index c6063fade..09d452c68 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -38,6 +38,7 @@ from torch.utils.tensorboard import SummaryWriter from transformer import Noam from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist @@ -350,9 +351,15 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -379,9 +386,7 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -514,9 +519,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -577,12 +580,27 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) + if "lang_bpe" in params.lang_dir: + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in params.lang_dir: + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have 'lang_bpe' or 'lang_phone' " + f"in its name): {params.lang_dir}" + ) logging.info("About to create model") model = Conformer( @@ -600,7 +618,9 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + # Note: find_unused_parameters=True is needed in case we + # want to set params.att_rate = 0 (i.e. att decoder is not trained) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Noam( model.parameters(), @@ -630,9 +650,7 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/train_phones.py b/egs/librispeech/ASR/conformer_ctc/train_phones.py deleted file mode 100755 index bed799ea5..000000000 --- a/egs/librispeech/ASR/conformer_ctc/train_phones.py +++ /dev/null @@ -1,735 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=35, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_phones", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0, - help="""The attention rate. - The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - (disabled by default; not supported yet) - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - lr_factor: The lr_factor for Noam optimizer. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - "num_decoder_layers": 6, - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for Noam - "weight_decay": 1e-6, - "lr_factor": 3.0, - "warm_step": 80000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler( - lexicon, - device=device, - ) - # Manually add the sos/eos ID with their default values - # from the BPE recipe which we're adapting here. - graph_compiler.sos_id = 1 - graph_compiler.eos_id = 1 - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - # Note: find_unused_parameters=True is needed because we're - # not training decoder at all by default (for now). - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - -# Workaround to see logs with some pytorch versions... -logging.info = logging.warning - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 570ed7d7a..e2ff03f61 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -75,9 +75,7 @@ class CtcTrainingGraphCompiler(object): # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially # is False, so we add epsilon self-loops here - fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops( - transcript_fsa - ) + fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa) fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)