diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py
index 38aff8834..bf4b2c248 100755
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/decode.py
@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search

 (2) beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
     --beam 4 \
@@ -124,7 +124,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="The experiment dir",
     )

diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py
index 80617847a..d976c5533 100755
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/train.py
@@ -19,27 +19,19 @@
 """
 Usage:

-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1,2"

-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
+./vq_pruned_transducer_stateless2/train.py \
+  --world-size 3 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
+  --manifest-dir data/quantizer-finetuned_hubert_xtralarge-36layer-f0238f68-bytes_per_frame-8/ \
+  --codebook-loss-scale 0.1 \
+  --num-codebooks=8 \
+  --exp-dir vq_pruned_transducer_stateless2/exp \
+  --full-libri 0 \
   --max-duration 300

-# For mix precision training:
-
-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
-  --max-duration 550
-
 """

@@ -60,9 +52,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -138,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -262,6 +255,23 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="""The scale of the codebook loss.
+        Note: currently, the scales of ctc_loss and att_loss sum to 1.0,
+        and codebook_loss is independent of them.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=8,
+        help="Number of codebooks.",
+    )
+
     return parser

@@ -333,6 +343,8 @@ def get_params() -> AttributeDict:
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "extra_output_layer": 5,  # 0-based index
+            "num_codebooks": 8,
         }
     )

@@ -348,6 +360,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        extra_output_layer=params.extra_output_layer,
     )
     return encoder

@@ -385,6 +398,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks,
     )
     return model

@@ -539,8 +553,20 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training:
+        cuts = batch["supervisions"]["cut"]
+        cuts_pre_mixed = [
+            c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+        ]
+        # pad_value=-100 is identical to ignore_index in the CE loss computation.
+        codebook_indices, codebook_indices_lens = collate_custom_field(
+            cuts_pre_mixed, "codebook_indices", pad_value=-100
+        )
+        codebook_indices = codebook_indices.to(device)
+    else:
+        codebook_indices = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -548,6 +574,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indices=codebook_indices,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -562,6 +589,9 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training

@@ -576,6 +606,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info

@@ -835,7 +867,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
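
Note (illustration, not part of the patch): the sketch below shows how the
collated codebook indices can feed a cross-entropy codebook (distillation)
loss. The tensor shapes, the joint `classifier` layer, and the codebook
vocabulary size are assumptions for illustration only; in the recipe the
loss is computed inside the Transducer model from the encoder tap at
params.extra_output_layer, and codebook_indices is produced by
collate_custom_field as in the diff above.

# Illustrative sketch only -- names, shapes, and codebook vocabulary size
# are assumptions, not the recipe's actual implementation.
import torch
import torch.nn.functional as F

N, T, C = 2, 100, 512                   # batch, frames, encoder dim (assumed)
num_codebooks, codebook_size = 8, 256   # 8 codebooks (entries-per-book assumed)

# Stand-in for the tap at encoder layer params.extra_output_layer.
middle_out = torch.randn(N, T, C)

# Stand-in for collate_custom_field output: one target per frame and
# codebook, with padded frames set to pad_value=-100.
codebook_indices = torch.randint(0, codebook_size, (N, T, num_codebooks))
codebook_indices[1, 80:, :] = -100      # second cut is shorter; tail is padding

# A hypothetical joint classifier predicting all codebooks at once.
classifier = torch.nn.Linear(C, num_codebooks * codebook_size)
logits = classifier(middle_out).reshape(N, T, num_codebooks, codebook_size)

# -100 is cross_entropy's default ignore_index, so padded frames
# contribute neither loss nor gradient.
codebook_loss = F.cross_entropy(
    logits.reshape(-1, codebook_size),
    codebook_indices.reshape(-1),
    ignore_index=-100,
)

With --codebook-loss-scale 0.1, this term enters the total objective as
loss += 0.1 * codebook_loss, and only during training, as the compute_loss
changes in the diff show.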