mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 12:42:20 +00:00
set default value of cr_loss_masked_scale to 1.0
This commit is contained in:
parent
07d6b12364
commit
cf796eefed
@ -189,7 +189,7 @@ class AsrModel(nn.Module):
|
|||||||
targets: torch.Tensor,
|
targets: torch.Tensor,
|
||||||
target_lengths: torch.Tensor,
|
target_lengths: torch.Tensor,
|
||||||
time_mask: Optional[torch.Tensor] = None,
|
time_mask: Optional[torch.Tensor] = None,
|
||||||
cr_loss_masked_scale: float = 3.0,
|
cr_loss_masked_scale: float = 1.0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Compute CTC loss with consistency regularization loss.
|
"""Compute CTC loss with consistency regularization loss.
|
||||||
Args:
|
Args:
|
||||||
@ -359,7 +359,7 @@ class AsrModel(nn.Module):
|
|||||||
spec_augment: Optional[SpecAugment] = None,
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
supervision_segments: Optional[torch.Tensor] = None,
|
supervision_segments: Optional[torch.Tensor] = None,
|
||||||
time_warp_factor: Optional[int] = 80,
|
time_warp_factor: Optional[int] = 80,
|
||||||
cr_loss_masked_scale: float = 3.0,
|
cr_loss_masked_scale: float = 1.0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user