Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-09 17:14:20 +00:00

different weight for masked/unmasked region

Commit c381b491f1 (parent 90024c308f)
@@ -23,7 +23,7 @@ from scaling import ScaledLinear
 
 from icefall.utils import add_sos
 
-from quantization.prediction import JointCodebookLoss
+from multi_quantization.prediction import JointCodebookLoss
 
 
 class Transducer(nn.Module):
@@ -41,6 +41,8 @@ class Transducer(nn.Module):
         joiner_dim: int,
         vocab_size: int,
         num_codebooks: int = 0,
+        masked_scale: float = 1.0,
+        unmasked_scale: float = 1.0,
     ):
         """
         Args:
@@ -60,6 +62,10 @@ class Transducer(nn.Module):
             contains unnormalized probs, i.e., not processed by log-softmax.
           num_codebooks:
             Used by distillation loss.
+          masked_scale:
+            scale of codebook loss of masked area.
+          unmasked_scale:
+            scale of codebook loss of unmasked area.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
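Read together with the loss hunk further down, the two new scales weight the per-frame codebook (distillation) loss independently over masked and unmasked frames. With M the set of SpecAugment-masked frame positions and \ell_t the unreduced per-frame codebook loss, the computation amounts to:

    \mathcal{L}_{\text{codebook}} = s_{\text{masked}} \sum_{t \in M} \ell_t + s_{\text{unmasked}} \sum_{t \notin M} \ell_t

Before this commit the two weights were tied as s_{\text{unmasked}} = 1 - s_{\text{masked}} (see the removed docstring lines below); they are now independent hyperparameters.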
@@ -79,6 +85,8 @@ class Transducer(nn.Module):
                 num_codebooks=num_codebooks,
                 reduction="none",
             )
+        self.masked_scale = masked_scale
+        self.unmasked_scale = unmasked_scale
 
     def forward(
         self,
@@ -91,7 +99,6 @@ class Transducer(nn.Module):
         warmup: float = 1.0,
         codebook_indexes: torch.Tensor = None,
         time_masked_area: torch.Tensor = None,
-        masked_scale: float = 1.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -119,9 +126,6 @@ class Transducer(nn.Module):
             codebook_indexes extracted from a teacher model.
           time_masked_area:
             masked area by SpecAugment, 1 represents masked.
-          masked_scale:
-            scale of codebook loss of masked area.
-            the unmasked_scale = 1 - masked_scale
         Returns:
           Return the transducer loss.
 
@@ -162,7 +166,8 @@ class Transducer(nn.Module):
             masked_loss = (time_masked_area * codebook_loss).sum()
             unmasked_loss = (~time_masked_area * codebook_loss).sum()
             codebook_loss = (
-                masked_scale * masked_loss + (1 - masked_scale) * unmasked_loss
+                self.masked_scale * masked_loss
+                + self.unmasked_scale * unmasked_loss
             )
         else:
             # when codebook index is not available.
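For reference, a minimal self-contained sketch of the weighting above, with dummy tensors; it assumes, as the `~` negation requires, that time_masked_area is a boolean tensor aligned frame-by-frame with the unreduced codebook loss:

import torch

# Per-frame codebook loss of shape (N, T); reduction="none" in the
# constructor keeps it unreduced so it can be split by the mask.
codebook_loss = torch.rand(2, 5)
# SpecAugment time mask: True (1) marks a masked frame.
time_masked_area = torch.rand(2, 5) > 0.5

masked_scale, unmasked_scale = 1.0, 1.0  # the constructor defaults

masked_loss = (time_masked_area * codebook_loss).sum()
unmasked_loss = (~time_masked_area * codebook_loss).sum()
total = masked_scale * masked_loss + unmasked_scale * unmasked_loss

# With both scales at 1.0 the split is a no-op: the mask partitions the
# frames, so the weighted sum equals the plain sum.
assert torch.isclose(total, codebook_loss.sum())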
(The hunks below are from the training script: get_parser, get_params, get_transducer_model and main.)

@@ -177,6 +177,18 @@ def get_parser():
         changed.""",
     )
 
+    parser.add_argument(
+        "--masked-scale",
+        type=float,
+        default=1.0,
+    )
+
+    parser.add_argument(
+        "--unmasked-scale",
+        type=float,
+        default=1.0,
+    )
+
     parser.add_argument(
         "--lr-batches",
         type=float,
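A small self-contained sketch of how the two new flags parse (the same add_argument calls as in the hunk above, exercised against a sample argv; argparse maps the dashed names to underscored attributes):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--masked-scale", type=float, default=1.0)
parser.add_argument("--unmasked-scale", type=float, default=1.0)

args = parser.parse_args(["--masked-scale", "1.5", "--unmasked-scale", "0.5"])
assert args.masked_scale == 1.5
assert args.unmasked_scale == 0.5
# Omitting both flags leaves the 1.0 defaults, matching the old behaviour.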
@@ -378,6 +390,8 @@ def get_params() -> AttributeDict:
             # two successive codebook_index are concatenated together.
             # Detailed in function Transducer::concat_sucessive_codebook_indexes.
             "num_codebooks": 16,  # used to construct distillation loss
+            "masked_scale": 1.0,
+            "unmasked_scale": 1.0,
         }
     )
 
@@ -436,6 +450,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         num_codebooks=params.num_codebooks
         if params.enable_distiallation
         else 0,
+        masked_scale=params.masked_scale,
+        unmasked_scale=params.unmasked_scale,
     )
     return model
 
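The defaults added to get_params() flow into the Transducer constructor through get_transducer_model(). A minimal sketch of that hand-off, assuming icefall is importable and showing only the fields this commit touches:

from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "num_codebooks": 16,
        "masked_scale": 1.0,
        "unmasked_scale": 1.0,
    }
)
# Inside get_transducer_model(params) the model is then built roughly as:
#   Transducer(..., num_codebooks=params.num_codebooks,
#              masked_scale=params.masked_scale,
#              unmasked_scale=params.unmasked_scale)
assert params.masked_scale == 1.0  # AttributeDict allows attribute access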
@@ -1090,7 +1106,9 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
+    args.exp_dir = Path(
+        f"{args.exp_dir}-masked_scale-{args.masked_scale}-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
+    )
 
     world_size = args.world_size
     assert world_size >= 1
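With the change above, the experiment directory name now records the sweep settings. A small sketch of the resulting path; the argument values are hypothetical, and spec_aug_max_frames_mask_fraction is an existing option that the f-string references:

from pathlib import Path
from types import SimpleNamespace

# Hypothetical stand-in for parser.parse_args().
args = SimpleNamespace(
    exp_dir="pruned_transducer_stateless6/exp",
    masked_scale=1.5,
    unmasked_scale=0.5,
    spec_aug_max_frames_mask_fraction=0.15,
)
exp_dir = Path(
    f"{args.exp_dir}-masked_scale-{args.masked_scale}"
    f"-un-{args.unmasked_scale}-{args.spec_aug_max_frames_mask_fraction}"
)
print(exp_dir)  # pruned_transducer_stateless6/exp-masked_scale-1.5-un-0.5-0.15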