Update train.py to use torchaudio's RNN-T loss.

This commit is contained in:
Fangjun Kuang 2022-04-14 11:43:35 +08:00
parent ad69dbeedf
commit fd6416e6c1
2 changed files with 0 additions and 15 deletions

View File

@@ -70,7 +70,6 @@ class Transducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
modified_transducer_prob: float = 0.0,
) -> torch.Tensor:
"""
Args:
@@ -82,8 +81,6 @@ class Transducer(nn.Module):
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
modified_transducer_prob:
The probability to use modified transducer loss.
Returns:
Return the transducer loss.
"""

View File

@@ -140,17 +140,6 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transducer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
parser.add_argument(
"--seed",
type=int,
@@ -414,7 +403,6 @@ def compute_loss(
x=feature,
x_lens=feature_lens,
y=y,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training