From c381b491f1cfa1e3fb6d505f66c17f9eccc1c1f2 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 4 Jun 2022 21:01:07 +0800 Subject: [PATCH] different weight for masked/unmasked region --- .../ASR/pruned_transducer_stateless6/model.py | 17 ++++++++++------ .../ASR/pruned_transducer_stateless6/train.py | 20 ++++++++++++++++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 5102f357e..305049d69 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -23,7 +23,7 @@ from scaling import ScaledLinear from icefall.utils import add_sos -from quantization.prediction import JointCodebookLoss +from multi_quantization.prediction import JointCodebookLoss class Transducer(nn.Module): @@ -41,6 +41,8 @@ class Transducer(nn.Module): joiner_dim: int, vocab_size: int, num_codebooks: int = 0, + masked_scale: float = 1.0, + unmasked_scale: float = 1.0, ): """ Args: @@ -60,6 +62,10 @@ class Transducer(nn.Module): contains unnormalized probs, i.e., not processed by log-softmax. num_codebooks: Used by distillation loss. + masked_scale: + scale of codebook loss of masked area. + unmasked_scale: + scale of codebook loss of unmasked area. """ super().__init__() assert isinstance(encoder, EncoderInterface), type(encoder) @@ -79,6 +85,8 @@ class Transducer(nn.Module): num_codebooks=num_codebooks, reduction="none", ) + self.masked_scale = masked_scale + self.unmasked_scale = unmasked_scale def forward( self, @@ -91,7 +99,6 @@ class Transducer(nn.Module): warmup: float = 1.0, codebook_indexes: torch.Tensor = None, time_masked_area: torch.Tensor = None, - masked_scale: float = 1.0, ) -> torch.Tensor: """ Args: @@ -119,9 +126,6 @@ class Transducer(nn.Module): codebook_indexes extracted from a teacher model. time_masked_area: masked area by SpecAugment, 1 represents masked. - masked_scale: - scale of codebook loss of masked area. - the unmasked_scale = 1 - masked_scale Returns: Return the transducer loss. @@ -162,7 +166,8 @@ class Transducer(nn.Module): masked_loss = (time_masked_area * codebook_loss).sum() unmasked_loss = (~time_masked_area * codebook_loss).sum() codebook_loss = ( - masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss + self.masked_scale * masked_loss + + self.unmasked_scale * unmasked_loss ) else: # when codebook index is not available. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index dbf87ff48..3cec4326e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -177,6 +177,18 @@ def get_parser(): changed.""", ) + parser.add_argument( + "--masked-scale", + type=float, + default=1.0, + ) + + parser.add_argument( + "--unmasked-scale", + type=float, + default=1.0, + ) + parser.add_argument( "--lr-batches", type=float, @@ -378,6 +390,8 @@ def get_params() -> AttributeDict: # two successive codebook_index are concatenated together. # Detailed in function Transducer::concat_sucessive_codebook_indexes. "num_codebooks": 16, # used to construct distillation loss + "masked_scale": 1.0, + "unmasked_scale": 1.0, } ) @@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: num_codebooks=params.num_codebooks if params.enable_distiallation else 0, + masked_scale=params.masked_scale, + unmasked_scale=params.unmasked_scale, ) return model @@ -1090,7 +1106,9 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) + args.exp_dir = Path( + f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}" + ) world_size = args.world_size assert world_size >= 1