diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index b0dbe72ad..a8bf2dc06 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -85,7 +85,7 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=0,
-        help="""Resume training from from this epoch.
+        help="""Resume training from this epoch.
         If it is positive, it will load checkpoint from
         conformer_ctc/exp/epoch-{start_epoch-1}.pt
         """,
diff --git a/egs/librispeech/ASR/conformer_mmi_phone/conformer.py b/egs/librispeech/ASR/conformer_mmi_phone/conformer.py
index 08287d686..71b8f227d 100644
--- a/egs/librispeech/ASR/conformer_mmi_phone/conformer.py
+++ b/egs/librispeech/ASR/conformer_mmi_phone/conformer.py
@@ -57,7 +57,7 @@ class Conformer(Transformer):
         normalize_before: bool = True,
         vgg_frontend: bool = False,
         is_espnet_structure: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
@@ -72,7 +72,7 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            mmi_loss=mmi_loss,
+            is_bpe=is_bpe,
             use_feat_batchnorm=use_feat_batchnorm,
         )
 
diff --git a/egs/librispeech/ASR/conformer_mmi_phone/train.py b/egs/librispeech/ASR/conformer_mmi_phone/train.py
index 402c4a2bb..54e6a5fcd 100755
--- a/egs/librispeech/ASR/conformer_mmi_phone/train.py
+++ b/egs/librispeech/ASR/conformer_mmi_phone/train.py
@@ -1,4 +1,20 @@
 #!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import argparse
 import logging
@@ -59,6 +75,23 @@
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
 
 
@@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor:  The subsampling factor for the model.
 
-        - start_epoch:  If it is not zero, load checkpoint `start_epoch-1`
-                        and continue training from that checkpoint.
-
-        - num_epochs:  Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to
                            select the model that has the lowest training loss. It is
                            updated during the training.
@@ -120,8 +148,6 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -130,13 +156,14 @@
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
+            "beam_size": 10,
             "use_pruned_intersect": False,
             "den_scale": 1.0,
             #
-            "att_rate": 0,  # If not zero, use attention decoder
+            "att_rate": 0.7,  # If not zero, use attention decoder
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 0,
+            "num_decoder_layers": 6,
             "is_espnet_structure": True,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,
@@ -288,15 +315,14 @@ def compute_loss(
 
     loss_fn = LFMMILoss(
         graph_compiler=graph_compiler,
+        output_beam=params.beam_size,
         den_scale=params.den_scale,
         use_pruned_intersect=params.use_pruned_intersect,
     )
 
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
-    assert params.att_rate == 0
     if params.att_rate != 0.0:
-        # TODO: not working
        token_ids = graph_compiler.texts_to_ids(texts)
        with torch.set_grad_enabled(is_training):
            if hasattr(model, "module"):
@@ -304,16 +330,16 @@
                     encoder_memory,
                     memory_mask,
                     token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
+                    sos_id=params.sos_id,
+                    eos_id=params.eos_id,
                 )
             else:
                 att_loss = model.decoder_forward(
                     encoder_memory,
                     memory_mask,
                     token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
+                    sos_id=params.sos_id,
+                    eos_id=params.eos_id,
                 )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
@@ -587,7 +613,6 @@ def run(rank, world_size, args):
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
 
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
@@ -617,12 +642,19 @@
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
+        is_bpe=False,
         is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    assert model.decoder_num_class == num_classes + 1
+
+    params.sos_id = num_classes
+    params.eos_id = num_classes
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
+    logging.info(params)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
diff --git a/egs/librispeech/ASR/conformer_mmi_phone/transformer.py b/egs/librispeech/ASR/conformer_mmi_phone/transformer.py
index 74e61b645..a7fad9438 100644
--- a/egs/librispeech/ASR/conformer_mmi_phone/transformer.py
+++ b/egs/librispeech/ASR/conformer_mmi_phone/transformer.py
@@ -42,7 +42,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         """
@@ -71,7 +71,9 @@
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          mmi_loss:
+          is_bpe:
+            True if the modeling unit is word pieces which has already included
+            SOS/EOS IDs.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
         """
@@ -123,7 +125,7 @@ class Transformer(nn.Module):
         )
 
         if num_decoder_layers > 0:
-            if mmi_loss:
+            if is_bpe is False:
                 self.decoder_num_class = (
                     self.num_classes + 1
                 )  # +1 for the sos/eos symbol
diff --git a/icefall/mmi.py b/icefall/mmi.py
index 6de1ab7b0..93fdc3bce 100644
--- a/icefall/mmi.py
+++ b/icefall/mmi.py
@@ -11,6 +11,7 @@ def _compute_mmi_loss_exact_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -79,7 +80,7 @@
     num_den_lats = k2.intersect_dense(
         num_den_reordered_graphs,
         dense_fsa_vec,
-        output_beam=8.0,
+        output_beam=output_beam,
         a_to_b_map=a_to_b_map,
     )
 
@@ -99,6 +100,7 @@ def _compute_mmi_loss_exact_non_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -112,9 +114,12 @@
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
-    # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=8.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     num_tot_scores = num_lats.get_tot_scores(
         log_semiring=True, use_double_scores=True
@@ -134,6 +139,7 @@ def _compute_mmi_loss_pruned(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -148,7 +154,9 @@
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.
@@ -156,7 +164,7 @@ def _compute_mmi_loss_pruned(
         den_graphs,
         dense_fsa_vec,
         search_beam=20.0,
-        output_beam=8.0,
+        output_beam=output_beam,
         min_active_states=30,
         max_active_states=10000,
     )
@@ -185,13 +193,15 @@ class LFMMILoss(nn.Module):
     def __init__(
         self,
         graph_compiler: MmiTrainingGraphCompiler,
+        output_beam: float,
         use_pruned_intersect: bool = False,
         den_scale: float = 1.0,
     ):
         super().__init__()
         self.graph_compiler = graph_compiler
-        self.den_scale = den_scale
+        self.output_beam = output_beam
         self.use_pruned_intersect = use_pruned_intersect
+        self.den_scale = den_scale
 
     def forward(
         self,
@@ -218,5 +228,6 @@
             dense_fsa_vec=dense_fsa_vec,
             texts=texts,
             graph_compiler=self.graph_compiler,
+            output_beam=self.output_beam,
             den_scale=self.den_scale,
         )
diff --git a/test/test_mmi_graph_compiler.py b/test/test_mmi_graph_compiler.py
index 80a1d9722..c36bfb045 100755
--- a/test/test_mmi_graph_compiler.py
+++ b/test/test_mmi_graph_compiler.py
@@ -155,6 +155,12 @@ def mmi_graph_compiler_test():
     )
     den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
 
+    texts = ["cat at cat", "ac at ca"]
+    token_ids = graph_compiler.texts_to_ids(texts)
+    # c a t  a t  c a t      a c  a t  SPN
+    expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
+    assert token_ids == expected_token_ids
+
 
 def test_main():
     generate_test_data()
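
Note (reviewer sketch, not part of the patch): the snippet below only illustrates how the updated LFMMILoss from icefall/mmi.py is constructed and called after this change, with output_beam now threaded through to k2.intersect_dense instead of the old hard-coded 8.0/10.0 beams. The graph_compiler, dense_fsa_vec, texts, params, and att_loss objects are assumed to be prepared exactly as in compute_loss above.

# Minimal usage sketch of the new LFMMILoss signature; all inputs are assumed
# to be set up as in compute_loss() in conformer_mmi_phone/train.py.
loss_fn = LFMMILoss(
    graph_compiler=graph_compiler,
    output_beam=params.beam_size,  # replaces the previously hard-coded beams
    den_scale=params.den_scale,
    use_pruned_intersect=params.use_pruned_intersect,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)

# When params.att_rate > 0, att_loss comes from model.decoder_forward(...) and
# the final loss interpolates the two terms, as in the patch:
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss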