mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
Support modified transducer.
This commit is contained in:
parent
963ac73c27
commit
6de0a849ce
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -119,6 +120,7 @@ class Transducer(nn.Module):
|
|||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
libri: bool = True,
|
libri: bool = True,
|
||||||
|
modified_transducer_prob: float = 0.0,
|
||||||
prune_range: int = 5,
|
prune_range: int = 5,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
@ -136,6 +138,8 @@ class Transducer(nn.Module):
|
|||||||
libri:
|
libri:
|
||||||
True to use the decoder and joiner for the LibriSpeech dataset.
|
True to use the decoder and joiner for the LibriSpeech dataset.
|
||||||
False to use the decoder and joiner for the GigaSpeech dataset.
|
False to use the decoder and joiner for the GigaSpeech dataset.
|
||||||
|
modified_transducer_prob:
|
||||||
|
The probability to use modified transducer loss.
|
||||||
prune_range:
|
prune_range:
|
||||||
The prune range for rnnt loss, it means how many symbols(context)
|
The prune range for rnnt loss, it means how many symbols(context)
|
||||||
we are considering for each frame to compute the loss.
|
we are considering for each frame to compute the loss.
|
||||||
@ -163,6 +167,16 @@ class Transducer(nn.Module):
|
|||||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||||
assert torch.all(x_lens > 0)
|
assert torch.all(x_lens > 0)
|
||||||
|
|
||||||
|
assert 0 <= modified_transducer_prob <= 1
|
||||||
|
|
||||||
|
if modified_transducer_prob == 0:
|
||||||
|
modified = False
|
||||||
|
elif random.random() < modified_transducer_prob:
|
||||||
|
# random.random() returns a float in the range [0, 1)
|
||||||
|
modified = True
|
||||||
|
else:
|
||||||
|
modified = False
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
row_splits = y.shape.row_splits(1)
|
row_splits = y.shape.row_splits(1)
|
||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
@ -213,6 +227,7 @@ class Transducer(nn.Module):
|
|||||||
lm_only_scale=lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
am_only_scale=am_scale,
|
am_only_scale=am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
modified=modified,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
@ -243,6 +258,7 @@ class Transducer(nn.Module):
|
|||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
modified=modified,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -180,6 +180,17 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modified-transducer-prob",
|
||||||
|
type=float,
|
||||||
|
default=0.25,
|
||||||
|
help="""The probability to use modified transducer loss.
|
||||||
|
In modified transducer, it limits the maximum number of symbols
|
||||||
|
per frame to 1. See also the option --max-sym-per-frame in
|
||||||
|
pruned_transducer_stateless_multi_datasets/decode.py
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prune-range",
|
"--prune-range",
|
||||||
type=int,
|
type=int,
|
||||||
@ -498,6 +509,7 @@ def compute_loss(
|
|||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
libri=libri,
|
libri=libri,
|
||||||
|
modified_transducer_prob=params.modified_transducer_prob,
|
||||||
prune_range=params.prune_range,
|
prune_range=params.prune_range,
|
||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user