Use modified transducer loss in training.

This commit is contained in:
Fangjun Kuang 2022-01-20 11:27:29 +08:00
parent f94ff19bfe
commit 92e524ea7f
2 changed files with 33 additions and 1 deletions

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import k2
import torch
import torch.nn as nn
@ -62,6 +64,7 @@ class Transducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
modified_transducer_prob: float = 0.0,
) -> torch.Tensor:
"""
Args:
@ -73,6 +76,8 @@ class Transducer(nn.Module):
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
modified_transducer_prob:
The probability to use modified transducer loss.
Returns:
Return the transducer loss.
"""
@ -113,6 +118,16 @@ class Transducer(nn.Module):
# reference stage
import optimized_transducer
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
one_sym_per_frame = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
one_sym_per_frame = True
else:
one_sym_per_frame = False
loss = optimized_transducer.transducer_loss(
logits=logits,
targets=y_padded,
@ -120,6 +135,7 @@ class Transducer(nn.Module):
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
one_sym_per_frame=one_sym_per_frame,
)
return loss

View File

@ -138,6 +138,17 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transducer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
transducer_stateless/decode.py
""",
)
return parser
@ -383,7 +394,12 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
modified_transducer_prob=params.modified_transducer_prob,
)
assert loss.requires_grad == is_training