diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py
index d51d5d4ef..47b38bdc9 100644
--- a/egs/librispeech/ASR/transducer/model.py
+++ b/egs/librispeech/ASR/transducer/model.py
@@ -121,7 +121,7 @@ class Transducer(nn.Module):
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
-            reduction="mean",
+            reduction="sum",
         )

         return loss
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 0fd838c7c..b5dbe02e9 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -25,6 +25,7 @@ from shutil import copyfile
 from typing import Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -36,21 +37,15 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
 from transducer.transformer import Noam

-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -107,22 +102,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--bpe-model",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
-    parser.add_argument(
-        "--att-rate",
-        type=float,
-        default=0.8,
-        help="""The attention rate.
-        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
-        """,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
     )

     parser.add_argument(
@@ -178,16 +161,8 @@ def get_params() -> AttributeDict:

         - attention_dim: Hidden dim for multi-head attention model.

-        - head: Number of heads of multi-head attention model.
-
         - num_decoder_layers: Number of decoder layer of transformer decoder.

-        - beam_size: It is used in k2.ctc_loss
-
-        - reduction: It is used in k2.ctc_loss
-
-        - use_double_scores: It is used in k2.ctc_loss
-
         - weight_decay: The weight_decay for the optimizer.

         - warm_step: The warm_step for Noam optimizer.
@@ -213,16 +188,9 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
             # decoder params
-            "vocab_size": 500,  # including blank
             "decoder_embedding_dim": 1024,
-            "blank_id": 0,
-            "sos_id": 1,
             "num_decoder_layers": 4,
             "decoder_hidden_dim": 512,
-            # parameters for loss
-            "beam_size": 10,
-            "reduction": "sum",
-            "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,
@@ -262,6 +230,27 @@
     return decoder


+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -352,8 +341,8 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
+    sp: spm.SentencePieceProcessor,
     batch: dict,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
@@ -367,86 +356,35 @@ def compute_loss(
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
-      graph_compiler:
-        It is used to build a decoding graph from a ctc topo and training
-        transcript. The training transcript is contained in the given `batch`,
-        while the ctc topo is built when this compiler is instantiated.
       is_training:
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
""" - device = graph_compiler.device + device = model.device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) + loss = model(x=feature, x_lens=feature_lens, y=y) assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - info["loss"] = loss.detach().cpu().item() + # We use reduction="sum" in computing the loss. 
+    # The displayed loss is the average loss over the batch
+    info["loss"] = loss.detach().cpu().item() / feature.size(0)

     return loss, info

@@ -454,7 +392,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: nn.Module,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -467,8 +405,8 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=False,
         )
         assert loss.requires_grad is False
@@ -489,7 +427,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -508,8 +446,6 @@ def train_one_epoch(
         The model for training.
       optimizer:
         The optimizer we are using.
-      graph_compiler:
-        It is used to convert transcripts to FSAs.
       train_dl:
         Dataloader for the training dataset.
       valid_dl:
@@ -530,8 +466,8 @@ def train_one_epoch(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=True,
         )
         # summary stats
@@ -567,7 +503,7 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
+                sp=sp,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -606,50 +542,37 @@ def run(rank, world_size, args):

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)

     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None

-    model = get_encoder_model(params)
-    model = get_decoder_model(params)
-    print(model)
-    return
-
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")

-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)

     logging.info("About to create model")
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
+    model = get_transducer_model(params)

     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
     if world_size > 1:
+        logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
+    model.device = device

     optimizer = Noam(
         model.parameters(),
@@ -659,7 +582,8 @@ def run(rank, world_size, args):
         weight_decay=params.weight_decay,
     )

-    if checkpoints:
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])

     librispeech = LibriSpeechAsrDataModule(args)
@@ -678,7 +602,7 @@ def run(rank, world_size, args):
         model=model,
         train_dl=train_dl,
         optimizer=optimizer,
-        graph_compiler=graph_compiler,
+        sp=sp,
         params=params,
     )

@@ -701,7 +625,7 @@ def run(rank, world_size, args):
             params=params,
             model=model,
             optimizer=optimizer,
-            graph_compiler=graph_compiler,
+            sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
             tb_writer=tb_writer,
@@ -726,7 +650,7 @@ def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -742,8 +666,8 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
+                sp=sp,
                 batch=batch,
-                graph_compiler=graph_compiler,
                 is_training=True,
             )
             loss.backward()
@@ -766,7 +690,6 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     world_size = args.world_size
     assert world_size >= 1
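For orientation, here is a minimal sketch (not part of the patch) of the data path the rewritten compute_loss relies on: transcripts are encoded into BPE token IDs with sentencepiece and wrapped in a k2.RaggedTensor, and because the model now returns a summed RNN-T loss (reduction="sum" in model.py), the value shown in the logs is divided by the batch size. The BPE model path and the placeholder loss value below are assumptions for illustration only.

# Sketch only: mirrors the text -> token-id -> k2.RaggedTensor path used by the
# new compute_loss(), plus the reduction="sum" display normalization.
import k2
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed path; same default as --bpe-model

texts = ["HELLO WORLD", "FOO BAR"]  # stands in for batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)  # list of lists of BPE token IDs
y = k2.RaggedTensor(y)  # rows may have different lengths

# The real model returns rnnt_loss summed over the batch; use a placeholder here.
loss = torch.tensor(123.4)
batch_size = len(texts)
print("displayed loss:", loss.item() / batch_size)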