Support modified transducer.

This commit is contained in:
Fangjun Kuang 2022-03-12 00:55:04 +08:00
parent 963ac73c27
commit 6de0a849ce
2 changed files with 28 additions and 0 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License.
import random
from typing import Optional
import k2
@ -119,6 +120,7 @@ class Transducer(nn.Module):
x_lens: torch.Tensor,
y: k2.RaggedTensor,
libri: bool = True,
modified_transducer_prob: float = 0.0,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
@ -136,6 +138,8 @@ class Transducer(nn.Module):
libri:
True to use the decoder and joiner for the LibriSpeech dataset.
False to use the decoder and joiner for the GigaSpeech dataset.
modified_transducer_prob:
The probability of using the modified transducer loss.
prune_range:
The prune range for the RNN-T loss, i.e., how many symbols (context)
we consider for each frame when computing the loss.
@ -163,6 +167,16 @@ class Transducer(nn.Module):
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
modified = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
modified = True
else:
modified = False
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
@ -213,6 +227,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
modified=modified,
reduction="sum",
return_grad=True,
)
@ -243,6 +258,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
modified=modified,
reduction="sum",
)

View File

@ -180,6 +180,17 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--modified-transducer-prob",
type=float,
default=0.25,
help="""The probability to use modified transducer loss.
The modified transducer limits the maximum number of symbols
per frame to 1. See also the option --max-sym-per-frame in
pruned_transducer_stateless_multi_datasets/decode.py
""",
)
parser.add_argument(
"--prune-range",
type=int,
@ -498,6 +509,7 @@ def compute_loss(
x_lens=feature_lens,
y=y,
libri=libri,
modified_transducer_prob=params.modified_transducer_prob,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,