set default value of cr_loss_masked_scale to 1.0

This commit is contained in:
yaozengwei 2024-09-06 10:32:36 +08:00
parent 07d6b12364
commit cf796eefed

View File

@ -189,7 +189,7 @@ class AsrModel(nn.Module):
targets: torch.Tensor,
target_lengths: torch.Tensor,
time_mask: Optional[torch.Tensor] = None,
cr_loss_masked_scale: float = 3.0,
cr_loss_masked_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
@ -359,7 +359,7 @@ class AsrModel(nn.Module):
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
cr_loss_masked_scale: float = 3.0,
cr_loss_masked_scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args: