mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Update train.py to use torchaudio's RNN-T loss.
This commit is contained in:
parent
ad69dbeedf
commit
fd6416e6c1
@ -70,7 +70,6 @@ class Transducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
modified_transducer_prob: float = 0.0,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -82,8 +81,6 @@ class Transducer(nn.Module):
|
|||||||
y:
|
y:
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
modified_transducer_prob:
|
|
||||||
The probability to use modified transducer loss.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
"""
|
"""
|
||||||
|
@ -140,17 +140,6 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--modified-transducer-prob",
|
|
||||||
type=float,
|
|
||||||
default=0.25,
|
|
||||||
help="""The probability to use modified transducer loss.
|
|
||||||
In modified transduer, it limits the maximum number of symbols
|
|
||||||
per frame to 1. See also the option --max-sym-per-frame in
|
|
||||||
transducer_stateless/decode.py
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -414,7 +403,6 @@ def compute_loss(
|
|||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
modified_transducer_prob=params.modified_transducer_prob,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
Loading…
x
Reference in New Issue
Block a user