Support modified transducer.

This commit is contained in:
Fangjun Kuang 2022-03-12 00:55:04 +08:00
parent 963ac73c27
commit 6de0a849ce
2 changed files with 28 additions and 0 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import random
from typing import Optional from typing import Optional
import k2 import k2
@ -119,6 +120,7 @@ class Transducer(nn.Module):
x_lens: torch.Tensor, x_lens: torch.Tensor,
y: k2.RaggedTensor, y: k2.RaggedTensor,
libri: bool = True, libri: bool = True,
modified_transducer_prob: float = 0.0,
prune_range: int = 5, prune_range: int = 5,
am_scale: float = 0.0, am_scale: float = 0.0,
lm_scale: float = 0.0, lm_scale: float = 0.0,
@ -136,6 +138,8 @@ class Transducer(nn.Module):
libri: libri:
True to use the decoder and joiner for the LibriSpeech dataset. True to use the decoder and joiner for the LibriSpeech dataset.
False to use the decoder and joiner for the GigaSpeech dataset. False to use the decoder and joiner for the GigaSpeech dataset.
modified_transducer_prob:
The probability to use modified transducer loss.
prune_range: prune_range:
The prune range for rnnt loss, it means how many symbols(context) The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss. we are considering for each frame to compute the loss.
@ -163,6 +167,16 @@ class Transducer(nn.Module):
encoder_out, x_lens = self.encoder(x, x_lens) encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0) assert torch.all(x_lens > 0)
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
modified = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
modified = True
else:
modified = False
# Now for the decoder, i.e., the prediction network # Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1) row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1] y_lens = row_splits[1:] - row_splits[:-1]
@ -213,6 +227,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale, lm_only_scale=lm_scale,
am_only_scale=am_scale, am_only_scale=am_scale,
boundary=boundary, boundary=boundary,
modified=modified,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
@ -243,6 +258,7 @@ class Transducer(nn.Module):
ranges=ranges, ranges=ranges,
termination_symbol=blank_id, termination_symbol=blank_id,
boundary=boundary, boundary=boundary,
modified=modified,
reduction="sum", reduction="sum",
) )

View File

@ -180,6 +180,17 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
In modified transduer, it limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
pruned_transducer_stateless_multi_datasets/decode.py
""",
)
parser.add_argument( parser.add_argument(
"--prune-range", "--prune-range",
type=int, type=int,
@ -498,6 +509,7 @@ def compute_loss(
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
libri=libri, libri=libri,
modified_transducer_prob=params.modified_transducer_prob,
prune_range=params.prune_range, prune_range=params.prune_range,
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,