different weight for masked/unmasked region

This commit is contained in:
Guo Liyong 2022-06-04 21:01:07 +08:00
parent 90024c308f
commit c381b491f1
2 changed files with 30 additions and 7 deletions

View File

@ -23,7 +23,7 @@ from scaling import ScaledLinear
from icefall.utils import add_sos
from quantization.prediction import JointCodebookLoss
from multi_quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
@ -41,6 +41,8 @@ class Transducer(nn.Module):
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
masked_scale: float = 1.0,
unmasked_scale: float = 1.0,
):
"""
Args:
@ -60,6 +62,10 @@ class Transducer(nn.Module):
contains unnormalized probs, i.e., not processed by log-softmax.
num_codebooks:
Used by distillation loss.
masked_scale:
scale of codebook loss of masked area.
unmasked_scale:
scale of codebook loss of unmasked area.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -79,6 +85,8 @@ class Transducer(nn.Module):
num_codebooks=num_codebooks,
reduction="none",
)
self.masked_scale = masked_scale
self.unmasked_scale = unmasked_scale
def forward(
self,
@ -91,7 +99,6 @@ class Transducer(nn.Module):
warmup: float = 1.0,
codebook_indexes: torch.Tensor = None,
time_masked_area: torch.Tensor = None,
masked_scale: float = 1.0,
) -> torch.Tensor:
"""
Args:
@ -119,9 +126,6 @@ class Transducer(nn.Module):
codebook_indexes extracted from a teacher model.
time_masked_area:
masked area by SpecAugment, 1 represents masked.
masked_scale:
scale of codebook loss of masked area.
the unmasked_scale = 1 - masked_scale
Returns:
Return the transducer loss.
@ -162,7 +166,8 @@ class Transducer(nn.Module):
masked_loss = (time_masked_area * codebook_loss).sum()
unmasked_loss = (~time_masked_area * codebook_loss).sum()
codebook_loss = (
masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss
self.masked_scale * masked_loss
+ self.unmasked_scale * unmasked_loss
)
else:
# when codebook index is not available.

View File

@ -177,6 +177,18 @@ def get_parser():
changed.""",
)
parser.add_argument(
"--masked-scale",
type=float,
default=1.0,
)
parser.add_argument(
"--unmasked-scale",
type=float,
default=1.0,
)
parser.add_argument(
"--lr-batches",
type=float,
@ -378,6 +390,8 @@ def get_params() -> AttributeDict:
# two successive codebook_index are concatenated together.
# Detailed in function Transducer::concat_sucessive_codebook_indexes.
"num_codebooks": 16, # used to construct distillation loss
"masked_scale": 1.0,
"unmasked_scale": 1.0,
}
)
@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
num_codebooks=params.num_codebooks
if params.enable_distiallation
else 0,
masked_scale=params.masked_scale,
unmasked_scale=params.unmasked_scale,
)
return model
@ -1090,7 +1106,9 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.exp_dir = Path(
f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
)
world_size = args.world_size
assert world_size >= 1