Apply delay penalty on transducer (#654)

* add delay penalty
* fix CI
* fix CI

parent 65b85b732c
commit 3600ce1b5f
.github/workflows/test.yml (vendored) | 3
@@ -79,6 +79,9 @@ jobs:
          pip uninstall -y protobuf
          pip install --no-binary protobuf protobuf

          pip install kaldifst
          pip install onnxruntime

          pip install -r requirements.txt

      - name: Install graphviz
@@ -81,6 +81,7 @@ class Transducer(nn.Module):
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
        delay_penalty: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -108,6 +109,11 @@ class Transducer(nn.Module):
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
          delay_penalty:
            A constant value used to penalize symbol delay, to encourage
            streaming models to emit symbols earlier.
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
        Returns:
          Return the transducer loss.
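Aside (not part of the diff): read loosely from the paper cited in the docstring, the penalty turns the training objective into a trade-off between transcription accuracy and emission time, with lambda standing for delay_penalty:

    \mathcal{L}_{\text{total}} \;\approx\; \mathcal{L}_{\text{rnnt}} \;+\; \lambda \cdot \mathbb{E}[\text{total symbol delay}]

With the default lambda = 0.0 training is unchanged; larger values push symbol emissions toward earlier frames, typically at a small accuracy cost.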
@@ -164,6 +170,7 @@ class Transducer(nn.Module):
            am_only_scale=am_scale,
            boundary=boundary,
            reduction=reduction,
            delay_penalty=delay_penalty,
            return_grad=True,
        )

@@ -196,6 +203,7 @@ class Transducer(nn.Module):
            ranges=ranges,
            termination_symbol=blank_id,
            boundary=boundary,
            delay_penalty=delay_penalty,
            reduction=reduction,
        )
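For orientation (not part of the diff): a minimal sketch of where delay_penalty ends up, assuming a k2 build new enough to accept it in rnnt_loss_smoothed (see the k2 issue linked above). The toy shapes B, T, U, V and all values here are made up:

    import k2
    import torch

    B, T, U, V = 2, 10, 4, 6              # batch, frames, symbols, vocab (toy)
    lm = torch.randn(B, U + 1, V)         # decoder ("LM") output
    am = torch.randn(B, T, V)             # encoder ("AM") output
    symbols = torch.randint(1, V, (B, U))
    # boundary rows are [0, 0, num_symbols, num_frames], as in icefall
    boundary = torch.tensor([[0, 0, U, T]] * B, dtype=torch.int64)

    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,             # blank_id
        lm_only_scale=0.25,
        am_only_scale=0.0,
        boundary=boundary,
        reduction="sum",
        delay_penalty=0.1,                # the new knob; 0.0 disables it
        return_grad=True,
    )
    print(simple_loss)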
@@ -318,6 +318,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser
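As a sanity check on the new flag (a toy parser for illustration, not the recipe's full get_parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="Penalize symbol delay; 0.0 disables the penalty.",
    )

    # argparse maps --delay-penalty to args.delay_penalty
    args = parser.parse_args(["--delay-penalty", "0.1"])
    assert args.delay_penalty == 0.1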
@@ -611,6 +621,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
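Note the warmup >= 2.0 guard: the penalty stays off until the model is past its warm-up phase, so early training is not distorted. A standalone restatement of that gating (the function name is ours, for illustration):

    def effective_delay_penalty(delay_penalty: float, warmup: float) -> float:
        """Return 0 until warm-up is done, mirroring the guard above."""
        return delay_penalty if warmup >= 2.0 else 0.0

    assert effective_delay_penalty(0.1, warmup=1.0) == 0.0  # still warming up
    assert effective_delay_penalty(0.1, warmup=2.5) == 0.1  # penalty active

The same change then lands in the other recipes' model.py and train.py files below.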
@@ -106,6 +106,7 @@ class Transducer(nn.Module):
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
        delay_penalty: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
          delay_penalty:
            A constant value used to penalize symbol delay, to encourage
            streaming models to emit symbols earlier.
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
        Returns:
          Return the transducer loss.

@@ -203,6 +209,7 @@ class Transducer(nn.Module):
            am_only_scale=am_scale,
            boundary=boundary,
            reduction=reduction,
            delay_penalty=delay_penalty,
            return_grad=True,
        )

@@ -235,6 +242,7 @@ class Transducer(nn.Module):
            ranges=ranges,
            termination_symbol=blank_id,
            boundary=boundary,
            delay_penalty=delay_penalty,
            reduction=reduction,
        )
@@ -341,6 +341,16 @@ def get_parser():
        help="The probability to select a batch from the GigaSpeech dataset",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser

@@ -665,6 +675,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -328,6 +328,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser

@@ -623,6 +633,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -23,7 +23,6 @@ To run this file, do:
    python ./pruned_transducer_stateless/test_model.py
"""

import torch
from train import get_params, get_transducer_model

@@ -43,8 +42,6 @@ def test_model():
    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")
    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
    torch.jit.script(model)


def test_model_streaming():
@@ -63,8 +60,6 @@ def test_model_streaming():
    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")
    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
    torch.jit.script(model)


def main():
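These hunks drop the torch.jit.ignore patch on forward before scripting, presumably because the model is now scriptable directly. A toy, self-contained illustration (TinyModel is hypothetical, not the icefall model) that a typed float argument like delay_penalty poses no problem for TorchScript:

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def forward(self, x: torch.Tensor, delay_penalty: float = 0.0) -> torch.Tensor:
            # TorchScript handles a typed float default like this without
            # needing torch.jit.ignore on forward.
            return x + delay_penalty

    scripted = torch.jit.script(TinyModel())
    print(scripted(torch.zeros(2), 0.1))  # tensor([0.1000, 0.1000])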
@@ -81,6 +81,7 @@ class Transducer(nn.Module):
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
        delay_penalty: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -108,6 +109,12 @@ class Transducer(nn.Module):
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
          delay_penalty:
            A constant value used to penalize symbol delay, to encourage
            streaming models to emit symbols earlier.
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
        Returns:
          Return the transducer loss.

@@ -164,6 +171,7 @@ class Transducer(nn.Module):
            am_only_scale=am_scale,
            boundary=boundary,
            reduction=reduction,
            delay_penalty=delay_penalty,
            return_grad=True,
        )

@@ -196,6 +204,7 @@ class Transducer(nn.Module):
            ranges=ranges,
            termination_symbol=blank_id,
            boundary=boundary,
            delay_penalty=delay_penalty,
            reduction=reduction,
        )
@@ -317,6 +317,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser

@@ -607,6 +617,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -106,6 +106,7 @@ class Transducer(nn.Module):
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
        delay_penalty: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
          delay_penalty:
            A constant value used to penalize symbol delay, to encourage
            streaming models to emit symbols earlier.
            See https://github.com/k2-fsa/k2/issues/955 and
            https://arxiv.org/pdf/2211.00490.pdf for more details.
        Returns:
          Return the transducer loss.

@@ -203,6 +209,7 @@ class Transducer(nn.Module):
            am_only_scale=am_scale,
            boundary=boundary,
            reduction=reduction,
            delay_penalty=delay_penalty,
            return_grad=True,
        )

@@ -235,6 +242,7 @@ class Transducer(nn.Module):
            ranges=ranges,
            termination_symbol=blank_id,
            boundary=boundary,
            delay_penalty=delay_penalty,
            reduction=reduction,
        )
@@ -328,6 +328,16 @@ def get_parser():
        help="The probability to select a batch from the GigaSpeech dataset",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)
    return parser

@@ -645,6 +655,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -335,6 +335,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser

@@ -638,6 +648,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -368,6 +368,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant value used to penalize symbol delay,
        to encourage streaming models to emit symbols earlier.
        See https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
    )

    add_model_arguments(parser)

    return parser

@@ -662,6 +672,7 @@ def compute_loss(
            lm_scale=params.lm_scale,
            warmup=warmup,
            reduction="none",
            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
        simple_loss_is_finite = torch.isfinite(simple_loss)
        pruned_loss_is_finite = torch.isfinite(pruned_loss)