From a4722dd7c09b0caa4a7062b0bd55618af752c3f0 Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Thu, 2 Dec 2021 17:16:48 +0800
Subject: [PATCH 1/2] training with codebook loss

---
 .../ASR/conformer_ctc/conformer.py          | 60 ++++++++++++++++++++
 egs/librispeech/ASR/conformer_ctc/decode.py | 15 ++++-
 egs/librispeech/ASR/conformer_ctc/train.py  | 37 +++++++++++
 icefall/utils.py                            |  6 +-
 4 files changed, 114 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index b19b94db1..73c60b2d0 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -25,6 +25,64 @@ from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask
 
 
+class CodeIndicesNet(nn.Module):
+    """Computes codebook logits and the codebook loss."""
+
+    def __init__(
+        self,
+        d_model=512,
+        quantizer_dim=512,
+        num_codebooks=4,
+    ):
+        """
+        Args:
+          d_model:
+            The dimension of the memory embeddings (input).
+          quantizer_dim:
+            The dimension of the quantizer, i.e., the number of classes of each CE loss.
+          num_codebooks:
+            The number of codebooks, i.e., the number of CE losses actually used.
+        """
+
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
+        # The default reduction is "mean"; "sum" is used here.
+        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
+        self.num_codebooks = num_codebooks
+        self.quantizer_dim = quantizer_dim
+
+    def forward(self, memory):
+        """
+        Args:
+          memory:
+            Memory embeddings of shape [N, T, C].
+        Returns:
+          A tensor of shape [N, T, num_codebooks * quantizer_dim].
+        """
+        x = self.linear1(memory)
+        return x
+
+    def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            Memory embeddings of shape [T, N, C].
+          target:
+            Codebook indices of shape [N, T, num_codebooks].
+
+        Returns:
+          The codebook loss, i.e., the sum of the
+          num_codebooks cross-entropy losses.
+        """
+
+        memory = memory.transpose(0, 1)  # [T, N, C] -> [N, T, C]
+        x = self.forward(memory)
+        x = x.reshape(-1, self.quantizer_dim)
+        target = target.reshape(-1)
+        ret = self.ce(x, target)
+        return ret
+
+
 class Conformer(Transformer):
     """
     Args:
@@ -92,6 +150,8 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity
 
+        self.cdidxnet = CodeIndicesNet()
+
     def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 63aed9358..f353adf37 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -499,14 +499,20 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        recog_path = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
+        )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        errs_filename = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
+        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=enable_log
             )
@@ -519,7 +525,10 @@ def save_results(
     )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    errs_info = (
+        params.exp_dir
+        / f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
+    )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index ec9b0b7c2..ceafe6cbe 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -31,6 +31,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -124,6 +125,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--codebook-weight",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently the rates of ctc_loss and att_loss sum to 1.0;
+        codebook_weight is independent of those two.
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
@@ -394,6 +405,29 @@ def compute_loss(
             eos_id=graph_compiler.eos_id,
         )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+
+        if params.codebook_weight != 0.0:
+
+            cuts = batch["supervisions"]["cut"]
+            # -100 matches the ignore_index of the CE loss computation.
+            codebook_indices, codebook_indices_lens = collate_custom_field(
+                cuts, "codebook_indices", pad_value=-100
+            )
+
+            assert (
+                codebook_indices.shape[0] == encoder_memory.shape[1]
+            )  # N: batch size
+            assert (
+                codebook_indices.shape[1] == encoder_memory.shape[0]
+            )  # T: number of frames
+            codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            # Unwrap DDP, if it is used, to reach the underlying Conformer.
+            mmodel = model.module if hasattr(model, "module") else model
+            codebook_loss = mmodel.cdidxnet.loss(
+                encoder_memory, target=codebook_indices
+            )
+
+            loss += params.codebook_weight * codebook_loss
     else:
         loss = ctc_loss
         att_loss = torch.tensor([0])
@@ -406,6 +440,9 @@ def compute_loss(
     if params.att_rate != 0.0:
         info["att_loss"] = att_loss.detach().cpu().item()
 
+    if params.codebook_weight != 0.0:
+        info["codebook_loss"] = cdidx_loss.detach().cpu().item()
+
     info["loss"] = loss.detach().cpu().item()
 
     return loss, info
diff --git a/icefall/utils.py b/icefall/utils.py
index ba9436fa4..e3ac196cf 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -126,7 +126,11 @@ def setup_logger(
         level = logging.CRITICAL
 
     logging.basicConfig(
-        filename=log_filename, format=formatter, level=level, filemode="w"
+        filename=log_filename,
+        format=formatter,
+        level=level,
+        filemode="w",
+        force=True,
     )
     if use_console:
         console = logging.StreamHandler()

From 54bcc167e17e0edb26f9d32a1f35510fbf323fee Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Thu, 2 Dec 2021 17:46:14 +0800
Subject: [PATCH 2/2] fix ci

---
 egs/librispeech/ASR/conformer_ctc/decode.py | 9 ++++++---
 egs/librispeech/ASR/conformer_ctc/train.py  | 2 +-
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index f353adf37..ed2da7b76 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -501,7 +501,8 @@ def save_results(
     for key, results in results_dict.items():
         recog_path = (
             params.exp_dir
-            / f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
+            / f"epoch-{params.epoch}-avg-{params.avg}-"
+            f"recogs-{test_set_name}-{key}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -511,7 +512,8 @@ def save_results(
         # ref/hyp pairs.
         errs_filename = (
             params.exp_dir
-            / f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
+            / f"epoch-{params.epoch}-avg-{params.avg}-"
+            f"errs-{test_set_name}-{key}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -527,7 +529,8 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
         params.exp_dir
-        / f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
+        / f"epoch-{params.epoch}-avg-{params.avg}-"
+        f"wer-summary-{test_set_name}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index ceafe6cbe..8a9bcfa8b 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -441,7 +441,7 @@ def compute_loss(
         info["att_loss"] = att_loss.detach().cpu().item()
 
     if params.codebook_weight != 0.0:
-        info["codebook_loss"] = cdidx_loss.detach().cpu().item()
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     info["loss"] = loss.detach().cpu().item()
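---

The note below is not part of either patch; it is a minimal, self-contained
sketch (plain PyTorch) of what the codebook loss in this series computes.
CodeIndicesNet is reproduced from patch 1/2 with docstrings elided; the
shapes, the random indices, and the pretend 15-frame cut are made-up
stand-ins for the Conformer encoder memory and for the codebook_indices
that collate_custom_field would return.

    import torch
    import torch.nn as nn

    # CodeIndicesNet as defined in patch 1/2 (docstrings elided).
    class CodeIndicesNet(nn.Module):
        def __init__(self, d_model=512, quantizer_dim=512, num_codebooks=4):
            super().__init__()
            self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
            self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
            self.num_codebooks = num_codebooks
            self.quantizer_dim = quantizer_dim

        def loss(self, memory, target):
            memory = memory.transpose(0, 1)  # [T, N, C] -> [N, T, C]
            x = self.linear1(memory)  # [N, T, num_codebooks * quantizer_dim]
            # One row of logits per (frame, codebook) pair.
            x = x.reshape(-1, self.quantizer_dim)
            target = target.reshape(-1)
            return self.ce(x, target)

    T, N, C = 20, 3, 512  # frames, batch size, d_model
    net = CodeIndicesNet(d_model=C)

    # Stand-in for the Conformer encoder output, shape [T, N, C].
    encoder_memory = torch.randn(T, N, C)

    # Stand-in for the collated codebook indices, shape
    # [N, T, num_codebooks], padded with -100 where a cut is shorter
    # than T frames.
    codebook_indices = torch.randint(
        0, net.quantizer_dim, (N, T, net.num_codebooks)
    )
    codebook_indices[0, 15:, :] = -100  # pretend cut 0 has only 15 frames

    # Padded positions contribute nothing to the loss because the pad
    # value matches the CE loss's ignore_index; the result is a scalar
    # summed over the remaining (frame, codebook) pairs.
    codebook_loss = net.loss(encoder_memory, target=codebook_indices)
    print(codebook_loss.item())

The pad_value=-100 passed to collate_custom_field in compute_loss and the
ignore_index=-100 of the CrossEntropyLoss therefore have to stay in sync:
changing one without the other would silently train on padding frames.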