Apply delay penalty on transducer (#654)

* add delay penalty

* fix CI

* fix CI
Zengwei Yao 2022-11-04 16:10:09 +08:00 committed by GitHub
parent 65b85b732c
commit 3600ce1b5f
13 changed files with 113 additions and 5 deletions

View File

@@ -79,6 +79,9 @@ jobs:
       pip uninstall -y protobuf
       pip install --no-binary protobuf protobuf
+      pip install kaldifst
+      pip install onnxruntime
       pip install -r requirements.txt
   - name: Install graphviz

View File

@@ -81,6 +81,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
@@ -164,6 +170,7 @@ class Transducer(nn.Module):
             am_only_scale=am_scale,
             boundary=boundary,
             reduction=reduction,
+            delay_penalty=delay_penalty,
             return_grad=True,
         )
@@ -196,6 +203,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )
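The model-side change above just threads one new keyword argument from Transducer.forward down into the two k2 loss calls (the smoothed simple loss and the pruned loss). Below is a minimal sketch of a call site with the new argument, assuming the usual forward signature (x, x_lens, y, ...) of this recipe's Transducer; `features`, `feature_lens`, and `token_ids` are illustrative placeholders, not names from this diff:

# Hedged sketch, not part of the commit: exercising the new argument.
# `features`, `feature_lens` and `token_ids` are illustrative placeholders.
simple_loss, pruned_loss = model(
    x=features,           # (N, T, C) acoustic features
    x_lens=feature_lens,  # (N,) number of valid frames per utterance
    y=token_ids,          # ragged tensor of label token IDs
    am_scale=0.0,
    lm_scale=0.0,
    reduction="sum",
    delay_penalty=0.001,  # small positive value; the default 0.0 disables it
)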

View File

@@ -318,6 +318,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -611,6 +621,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)
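Note the gate on the last added line: the penalty is passed through only once warmup >= 2.0, so training starts with a plain transducer loss and the configured penalty switches on after the warm-up phase. A small sketch of that behavior, using a hypothetical --delay-penalty value of 0.001:

# Hedged sketch of the warmup gate in compute_loss above,
# with a hypothetical --delay-penalty 0.001 (the default 0.0 means off).
delay_penalty = 0.001
for warmup in (0.5, 1.0, 1.9, 2.0, 3.5):
    effective = delay_penalty if warmup >= 2.0 else 0
    print(f"warmup={warmup}: delay_penalty={effective}")
# Prints 0 while warmup < 2.0 and 0.001 once warmup reaches 2.0.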

View File

@@ -106,6 +106,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
@@ -203,6 +209,7 @@ class Transducer(nn.Module):
             am_only_scale=am_scale,
             boundary=boundary,
             reduction=reduction,
+            delay_penalty=delay_penalty,
             return_grad=True,
         )
@@ -235,6 +242,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )

View File

@@ -341,6 +341,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -665,6 +675,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

View File

@@ -328,6 +328,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -623,6 +633,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

View File

@@ -23,7 +23,6 @@ To run this file, do:
     python ./pruned_transducer_stateless/test_model.py
 """
-import torch
 from train import get_params, get_transducer_model
@@ -43,8 +42,6 @@ def test_model():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)
 def test_model_streaming():
@@ -63,8+60,6 @@ def test_model_streaming():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)
 def main():

View File

@@ -81,6 +81,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,12 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
+        Returns:
         Returns:
           Return the transducer loss.
@@ -164,6 +171,7 @@ class Transducer(nn.Module):
             am_only_scale=am_scale,
             boundary=boundary,
             reduction=reduction,
+            delay_penalty=delay_penalty,
             return_grad=True,
         )
@@ -196,6 +204,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )

View File

@@ -317,6 +317,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -607,6 +617,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

View File

@@ -106,6 +106,7 @@ class Transducer(nn.Module):
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@ class Transducer(nn.Module):
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.
@@ -203,6 +209,7 @@ class Transducer(nn.Module):
             am_only_scale=am_scale,
             boundary=boundary,
             reduction=reduction,
+            delay_penalty=delay_penalty,
             return_grad=True,
         )
@@ -235,6 +242,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
+            delay_penalty=delay_penalty,
             reduction=reduction,
         )

View File

@@ -328,6 +328,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -645,6 +655,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

View File

@@ -335,6 +335,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -638,6 +648,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
     )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)

View File

@@ -368,6 +368,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
     add_model_arguments(parser)
     return parser
@@ -662,6 +672,7 @@ def compute_loss(
         lm_scale=params.lm_scale,
         warmup=warmup,
         reduction="none",
+        delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
    )
     simple_loss_is_finite = torch.isfinite(simple_loss)
     pruned_loss_is_finite = torch.isfinite(pruned_loss)