Mirror of https://github.com/k2-fsa/icefall.git
Apply delay penalty on transducer (#654)

* add delay penalty
* fix CI
* fix CI

parent 65b85b732c
commit 3600ce1b5f
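
Background for the change: per the paper referenced throughout this diff (https://arxiv.org/pdf/2211.00490.pdf), the delay penalty trades a small amount of accuracy for lower emission latency by penalizing late symbol emissions during training. Schematically, with penalty scale \(\lambda\) (the new delay_penalty value), the objective becomes roughly

\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{rnnt}} + \lambda \cdot \mathbb{E}[\text{symbol emission delay}]

This is only a schematic summary; the actual penalty is applied inside k2's RNN-T lattice computation (see https://github.com/k2-fsa/k2/issues/955 for the exact formulation).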
.github/workflows/test.yml (vendored, 3 lines changed)
@@ -79,6 +79,9 @@ jobs:
           pip uninstall -y protobuf
           pip install --no-binary protobuf protobuf
 
+          pip install kaldifst
+          pip install onnxruntime
+
           pip install -r requirements.txt
 
       - name: Install graphviz
@@ -81,6 +81,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
 
@@ -164,6 +170,7 @@ class Transducer(nn.Module):
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )
 
@@ -196,6 +203,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )
 
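The two call sites above simply forward the new argument to k2's RNN-T losses. Below is a minimal, self-contained sketch of the first call site, assuming a k2 build whose rnnt_loss_smoothed already accepts delay_penalty (the version referenced in the k2 issue above); the shapes, scales, and penalty value are illustrative, not recommendations:

import k2
import torch

# Toy dimensions: batch, frames, symbols, vocab size (blank id = 0).
B, T, S, V, blank_id = 2, 10, 3, 5, 0

am = torch.randn(B, T, V)      # encoder ("acoustic") output
lm = torch.randn(B, S + 1, V)  # decoder ("label") output
symbols = torch.randint(1, V, (B, S))

# k2 boundary rows are [0, 0, num_symbols, num_frames].
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
    reduction="sum",
    delay_penalty=0.1,  # new argument; 0.0 reproduces the old behaviour
    return_grad=True,
)
print(simple_loss)

The pruned call site is analogous: k2.rnnt_loss_pruned takes the same delay_penalty keyword alongside logits, symbols, ranges, termination_symbol, boundary, and reduction, exactly as in the hunk above.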
@@ -318,6 +318,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
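For reference, argparse maps the new --delay-penalty flag to params.delay_penalty (dashes become underscores), which is what compute_loss reads below. A stand-alone sketch of just that flag:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--delay-penalty",
    type=float,
    default=0.0,  # the default keeps training identical to before
    help="A constant value used to penalize symbol delay.",
)

args = parser.parse_args(["--delay-penalty", "0.1"])
print(args.delay_penalty)  # -> 0.1 (illustrative value only)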
@@ -611,6 +621,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
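Note the gate on the new line above: the penalty stays at zero until warmup reaches 2.0, so the model first trains without it (in icefall, warmup grows with the number of processed batches, roughly batch count divided by the warm-up step count). The same logic in isolation:

def effective_delay_penalty(delay_penalty: float, warmup: float) -> float:
    # Mirrors `params.delay_penalty if warmup >= 2.0 else 0` above.
    return delay_penalty if warmup >= 2.0 else 0.0

assert effective_delay_penalty(0.1, warmup=0.5) == 0.0  # early training
assert effective_delay_penalty(0.1, warmup=3.0) == 0.1  # after warm-up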
@@ -106,6 +106,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
 
@@ -203,6 +209,7 @@ class Transducer(nn.Module):
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )
 
@@ -235,6 +242,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )
 
@@ -341,6 +341,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -665,6 +675,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -328,6 +328,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -623,6 +633,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -23,7 +23,6 @@ To run this file, do:
     python ./pruned_transducer_stateless/test_model.py
 """
 
-import torch
 from train import get_params, get_transducer_model
 
 
@@ -43,8 +42,6 @@ def test_model():
 
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)
 
 
 def test_model_streaming():
@@ -63,8 +60,6 @@ def test_model_streaming():
 
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)
 
 
 def main():
@@ -81,6 +81,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,12 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
+
         Returns:
           Return the transducer loss.
 
@@ -164,6 +171,7 @@ class Transducer(nn.Module):
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )
 
@@ -196,6 +204,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )
 
@@ -317,6 +317,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -607,6 +617,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -106,6 +106,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
 
@@ -203,6 +209,7 @@ class Transducer(nn.Module):
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )
 
@@ -235,6 +242,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )
 
@@ -328,6 +328,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
     return parser
 
@@ -645,6 +655,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -335,6 +335,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -638,6 +648,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
@@ -368,6 +368,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -662,6 +672,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
|
Loading…
x
Reference in New Issue
Block a user