Merge branch 'master' of https://github.com/k2-fsa/icefall into surt

Desh Raj 2023-01-26 13:14:04 -05:00
commit bf5d574541
6 changed files with 75 additions and 3 deletions

View File

@@ -30,8 +30,9 @@ In icefall, we implement the streaming conformer the way just like what `WeNet <
See :doc:`Pruned transducer statelessX <librispeech/pruned_transducer_stateless>` for more details.
.. HINT::
If you want to adapt a non-streaming conformer model to be streaming, please refer
to `this pull request <https://github.com/k2-fsa/icefall/pull/454>`_.
If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming modes, please refer
to `this pull request <https://github.com/k2-fsa/icefall/pull/454>`_. After adding the code needed for streaming training,
you have to re-train the model with the extra arguments mentioned in the docs above to get a streaming model.
Streaming Emformer

View File

@@ -50,6 +50,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@@ -261,6 +262,7 @@ def main():
model.eval()
if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
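For context, the export path this comment belongs to typically continues as sketched below. This is a reconstruction, not part of the diff; treat the exact lines, including the `cpu_jit.pt` filename, as assumptions.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)  # skip compiling forward()
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"  # assumed output path
model.save(str(filename))
Ignoring `forward()` keeps `torch.jit.script()` from trying to compile a method whose ragged-tensor argument is not scriptable.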

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless3/lstmp.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@@ -82,7 +82,13 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -368,6 +374,21 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--filter-uneven-sized-batch",
type=str2bool,
default=True,
help="""Whether to filter uneven-sized minibatch.
For the uneven-sized batch, the total duration after padding would possibly
cause OOM. Hence, for each batch, which is sorted descendingly by length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed `allowed_max_frames`:
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
where `max_frames = max_duration * 1000 // frame_shift_ms`.
We set allowed_excess_duration_ratio=0.1.
""",
)
add_model_arguments(parser)
return parser
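As a quick sanity check of the formula in the help text above, here are illustrative numbers (max_duration=300 seconds is an assumption, not a recipe default):
max_frames = 300 * 1000 // 10.0  # max_duration * 1000 // frame_shift_ms -> 30000.0
allowed_max_frames = int(max_frames * (1.0 + 0.1))  # -> 33000
So a batch padded to T frames may keep at most allowed_max_frames // T utterances.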
@@ -420,6 +441,9 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"frame_shift_ms": 10.0,
# only used when params.filter_uneven_sized_batch is True
"allowed_excess_duration_ratio": 0.1,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@@ -642,6 +666,13 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
if params.filter_uneven_sized_batch:
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(
max_frames * (1.0 + params.allowed_excess_duration_ratio)
)
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)

View File

@@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool:
import importlib
return all(importlib.util.find_spec(m) is not None for m in modules)
def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
"""For the uneven-sized batch, the total duration after padding would possibly
cause OOM. Hence, for each batch, which is sorted descendingly by length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed the given allow_max_frames.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
allowed_max_frames:
The maximum number of frames allowed in a batch.
"""
features = batch["inputs"]
supervisions = batch["supervisions"]
N, T, _ = features.size()
assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
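# After padding, every retained utterance occupies T frames, so keeping k
# utterances costs k * T frames; the largest k within the budget is
# allowed_max_frames // T.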
keep_num_utt = allowed_max_frames // T
if keep_num_utt >= N:
return batch
# Note: we assume the samples in the batch are sorted in descending order of length
logging.info(
f"Filtering uneven-sized batch, original batch size is {N}, "
f"retained batch size is {keep_num_utt}."
)
batch["inputs"] = features[:keep_num_utt]
for k, v in supervisions.items():
assert len(v) == N, (len(v), N)
batch["supervisions"][k] = v[:keep_num_utt]
return batch
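A minimal usage sketch of the helper above; the batch layout mirrors what `lhotse.dataset.K2SpeechRecognitionDataset()` produces, but the shapes and values are made up for illustration:
import torch

from icefall.utils import filter_uneven_sized_batch

# Hypothetical batch: 4 utterances padded to 1000 frames each, 80-dim
# features, already sorted in descending order of length.
batch = {
    "inputs": torch.randn(4, 1000, 80),
    "supervisions": {"num_frames": torch.tensor([1000, 900, 500, 400])},
}
batch = filter_uneven_sized_batch(batch, allowed_max_frames=3300)
# keep_num_utt = 3300 // 1000 = 3, so the shortest utterance is dropped
# and batch["inputs"].shape becomes (3, 1000, 80).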