mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
Use modified transducer loss in training.
This commit is contained in:
parent
f94ff19bfe
commit
92e524ea7f
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -62,6 +64,7 @@ class Transducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
|
modified_transducer_prob: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -73,6 +76,8 @@ class Transducer(nn.Module):
|
|||||||
y:
|
y:
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
|
modified_transducer_prob:
|
||||||
|
The probability to use modified transducer loss.
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
"""
|
"""
|
||||||
@ -113,6 +118,16 @@ class Transducer(nn.Module):
|
|||||||
# reference stage
|
# reference stage
|
||||||
import optimized_transducer
|
import optimized_transducer
|
||||||
|
|
||||||
|
assert 0 <= modified_transducer_prob <= 1
|
||||||
|
|
||||||
|
if modified_transducer_prob == 0:
|
||||||
|
one_sym_per_frame = False
|
||||||
|
elif random.random() < modified_transducer_prob:
|
||||||
|
# random.random() returns a float in the range [0, 1)
|
||||||
|
one_sym_per_frame = True
|
||||||
|
else:
|
||||||
|
one_sym_per_frame = False
|
||||||
|
|
||||||
loss = optimized_transducer.transducer_loss(
|
loss = optimized_transducer.transducer_loss(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
targets=y_padded,
|
targets=y_padded,
|
||||||
@ -120,6 +135,7 @@ class Transducer(nn.Module):
|
|||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
blank=blank_id,
|
blank=blank_id,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
|
one_sym_per_frame=one_sym_per_frame,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -138,6 +138,17 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modified-transducer-prob",
|
||||||
|
type=float,
|
||||||
|
default=0.25,
|
||||||
|
help="""The probability to use modified transducer loss.
|
||||||
|
In modified transducer, it limits the maximum number of symbols
|
||||||
|
per frame to 1. See also the option --max-sym-per-frame in
|
||||||
|
transducer_stateless/decode.py
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -383,7 +394,12 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
loss = model(
|
||||||
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
y=y,
|
||||||
|
modified_transducer_prob=params.modified_transducer_prob,
|
||||||
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user