Mirror of https://github.com/k2-fsa/icefall.git
Different weights for masked/unmasked regions
commit c381b491f1
parent 90024c308f
@@ -23,7 +23,7 @@ from scaling import ScaledLinear

 from icefall.utils import add_sos

-from quantization.prediction import JointCodebookLoss
+from multi_quantization.prediction import JointCodebookLoss


 class Transducer(nn.Module):
@@ -41,6 +41,8 @@ class Transducer(nn.Module):
         joiner_dim: int,
         vocab_size: int,
         num_codebooks: int = 0,
+        masked_scale: float = 1.0,
+        unmasked_scale: float = 1.0,
     ):
         """
         Args:
@@ -60,6 +62,10 @@ class Transducer(nn.Module):
             contains unnormalized probs, i.e., not processed by log-softmax.
           num_codebooks:
             Used by distillation loss.
+          masked_scale:
+            Scale of the codebook loss in the masked area.
+          unmasked_scale:
+            Scale of the codebook loss in the unmasked area.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -79,6 +85,8 @@ class Transducer(nn.Module):
                 num_codebooks=num_codebooks,
                 reduction="none",
             )
+        self.masked_scale = masked_scale
+        self.unmasked_scale = unmasked_scale

     def forward(
         self,
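The hunk above is the heart of the patch: the two scales become module attributes, and the codebook loss is constructed with reduction="none" so it yields one value per frame rather than a scalar. A minimal sketch of that constructor pattern, with a hypothetical nn.MSELoss standing in for JointCodebookLoss:

    import torch.n n as nn  # noqa: fix below

    import torch.nn as nn

    # Sketch only: nn.MSELoss is a stand-in for JointCodebookLoss; the point
    # is that reduction="none" keeps per-frame losses so the masked and
    # unmasked regions can be weighted differently later in forward().
    class DistillationHead(nn.Module):
        def __init__(self, masked_scale: float = 1.0, unmasked_scale: float = 1.0):
            super().__init__()
            self.codebook_loss_net = nn.MSELoss(reduction="none")  # stand-in
            self.masked_scale = masked_scale
            self.unmasked_scale = unmasked_scale

The next two hunks remove the per-call masked_scale argument from forward(); the scales are now fixed at construction time.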
@@ -91,7 +99,6 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         codebook_indexes: torch.Tensor = None,
         time_masked_area: torch.Tensor = None,
-        masked_scale: float = 1.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -119,9 +126,6 @@ class Transducer(nn.Module):
             codebook_indexes extracted from a teacher model.
           time_masked_area:
             masked area by SpecAugment, 1 represents masked.
-          masked_scale:
-            scale of codebook loss of masked area.
-            the unmasked_scale = 1 - masked_scale
         Returns:
           Return the transducer loss.

@@ -162,7 +166,8 @@ class Transducer(nn.Module):
                 masked_loss = (time_masked_area * codebook_loss).sum()
                 unmasked_loss = (~time_masked_area * codebook_loss).sum()
                 codebook_loss = (
-                    masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss
+                    self.masked_scale * masked_loss
+                    + self.unmasked_scale * unmasked_loss
                 )
             else:
                 # when codebook index is not available.
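The weighting itself is only a few lines. A self-contained sketch of the computation in the hunk above, under assumed shapes (codebook_loss per frame of shape (N, T), time_masked_area a boolean mask of the same shape):

    import torch

    def weighted_codebook_loss(
        codebook_loss: torch.Tensor,     # (N, T) per-frame loss, from reduction="none"
        time_masked_area: torch.Tensor,  # (N, T) bool, True where SpecAugment masked
        masked_scale: float,
        unmasked_scale: float,
    ) -> torch.Tensor:
        # Split the per-frame loss by the SpecAugment time mask, then recombine
        # with separate weights for the masked and unmasked regions.
        masked_loss = (time_masked_area * codebook_loss).sum()
        unmasked_loss = (~time_masked_area * codebook_loss).sum()
        return masked_scale * masked_loss + unmasked_scale * unmasked_loss

With both scales at their default of 1.0, this reduces to the plain sum over all frames; the replaced expression instead coupled the two weights as masked_scale and 1 - masked_scale. The remaining hunks below change the training script (get_parser, get_params, get_transducer_model, main).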
@@ -177,6 +177,18 @@ def get_parser():
         changed.""",
     )

+    parser.add_argument(
+        "--masked-scale",
+        type=float,
+        default=1.0,
+    )
+
+    parser.add_argument(
+        "--unmasked-scale",
+        type=float,
+        default=1.0,
+    )
+
     parser.add_argument(
         "--lr-batches",
         type=float,
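Neither new option carries a help string; presumably they are passed like any other training flag, e.g. (hypothetical invocation, script path abbreviated):

    ./train.py --masked-scale 1.0 --unmasked-scale 0.3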
@@ -378,6 +390,8 @@ def get_params() -> AttributeDict:
             # two successive codebook_index are concatenated together.
             # Detailed in function Transducer::concat_sucessive_codebook_indexes.
             "num_codebooks": 16,  # used to construct distillation loss
+            "masked_scale": 1.0,
+            "unmasked_scale": 1.0,
         }
     )

@@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         num_codebooks=params.num_codebooks
         if params.enable_distiallation
         else 0,
+        masked_scale=params.masked_scale,
+        unmasked_scale=params.unmasked_scale,
     )
     return model

@@ -1090,7 +1106,9 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
+    args.exp_dir = Path(
+        f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
+    )

     world_size = args.world_size
     assert world_size >= 1
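The experiment directory now encodes the sweep parameters. A tiny illustration of the resulting naming (example values, not defaults from this commit):

    from pathlib import Path

    # Hypothetical values chosen for illustration only.
    exp_dir, masked_scale, unmasked_scale, mask_fraction = "exp", 1.0, 0.3, 0.2
    print(Path(f"{exp_dir}-masked_scale-{masked_scale}-un-{unmasked_scale}-{mask_fraction}"))
    # exp-masked_scale-1.0-un-0.3-0.2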