mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 06:04:18 +00:00
Support modified transducer.
This commit is contained in:
parent
963ac73c27
commit
6de0a849ce
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import k2
|
||||
@ -119,6 +120,7 @@ class Transducer(nn.Module):
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
libri: bool = True,
|
||||
modified_transducer_prob: float = 0.0,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
@ -136,6 +138,8 @@ class Transducer(nn.Module):
|
||||
libri:
|
||||
True to use the decoder and joiner for the LibriSpeech dataset.
|
||||
False to use the decoder and joiner for the GigaSpeech dataset.
|
||||
modified_transducer_prob:
|
||||
The probability to use modified transducer loss.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
@ -163,6 +167,16 @@ class Transducer(nn.Module):
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
assert 0 <= modified_transducer_prob <= 1
|
||||
|
||||
if modified_transducer_prob == 0:
|
||||
modified = False
|
||||
elif random.random() < modified_transducer_prob:
|
||||
# random.random() returns a float in the range [0, 1)
|
||||
modified = True
|
||||
else:
|
||||
modified = False
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
@ -213,6 +227,7 @@ class Transducer(nn.Module):
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
modified=modified,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
@ -243,6 +258,7 @@ class Transducer(nn.Module):
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
modified=modified,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
|
@ -180,6 +180,17 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modified-transducer-prob",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""The probability to use modified transducer loss.
|
||||
In modified transduer, it limits the maximum number of symbols
|
||||
per frame to 1. See also the option --max-sym-per-frame in
|
||||
pruned_transducer_stateless_multi_datasets/decode.py
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
@ -498,6 +509,7 @@ def compute_loss(
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
libri=libri,
|
||||
modified_transducer_prob=params.modified_transducer_prob,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
|
Loading…
x
Reference in New Issue
Block a user