diff --git a/docs/source/recipes/Streaming-ASR/introduction.rst b/docs/source/recipes/Streaming-ASR/introduction.rst
index d81156659..e1382e77d 100644
--- a/docs/source/recipes/Streaming-ASR/introduction.rst
+++ b/docs/source/recipes/Streaming-ASR/introduction.rst
@@ -30,8 +30,9 @@ In icefall, we implement the streaming conformer the way just like what `WeNet <
 See :doc:`Pruned transducer statelessX ` for more details.
 
 .. HINT::
-   If you want to adapt a non-streaming conformer model to be streaming, please refer
-   to `this pull request `_.
+   If you want to modify a non-streaming conformer recipe to support both streaming and non-streaming models, please refer
+   to `this pull request `_. After adding the code needed for streaming training,
+   you have to re-train it with the extra arguments mentioned in the docs above to get a streaming model.
 
 Streaming Emformer
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
index 401b3ef3a..8f33f5b05 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py
@@ -50,6 +50,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
@@ -261,6 +262,7 @@ def main():
     model.eval()
 
     if params.jit:
+        convert_scaled_to_non_scaled(model, inplace=True)
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
         # Otherwise, one of its arguments is a ragged tensor and is not
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py
new file mode 120000
index 000000000..9aa06f82f
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/lstmp.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/lstmp.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py
new file mode 120000
index 000000000..3b667058d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless3/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 31a3a0505..a806244ff 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -82,7 +82,13 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    filter_uneven_sized_batch,
+    setup_logger,
+    str2bool,
+)
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
@@ -368,6 +374,21 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--filter-uneven-sized-batch",
+        type=str2bool,
+        default=True,
+        help="""Whether to filter uneven-sized minibatches.
+        For an uneven-sized batch, the total number of frames after padding may cause OOM.
+        Hence, for each batch, which is sorted in descending order by length,
+        we simply drop the last few shortest samples, so that the retained total frames
+        (after padding) do not exceed `allowed_max_frames`:
+        `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+        where `max_frames = max_duration * 1000 // frame_shift_ms`.
+        We set allowed_excess_duration_ratio=0.1.
+        """,
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -420,6 +441,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
+            "frame_shift_ms": 10.0,
+            # only used when params.filter_uneven_sized_batch is True
+            "allowed_excess_duration_ratio": 0.1,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -642,6 +666,13 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
+    if params.filter_uneven_sized_batch:
+        max_frames = params.max_duration * 1000 // params.frame_shift_ms
+        allowed_max_frames = int(
+            max_frames * (1.0 + params.allowed_excess_duration_ratio)
+        )
+        batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
diff --git a/icefall/utils.py b/icefall/utils.py
index 99e51a2a9..ba0b7fe43 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1395,3 +1395,39 @@ def is_module_available(*modules: str) -> bool:
     import importlib
 
     return all(importlib.util.find_spec(m) is not None for m in modules)
+
+
+def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int) -> dict:
+    """For an uneven-sized batch, the total number of frames after padding may
+    cause OOM. Hence, for each batch, which is sorted in descending order by
+    length, we simply drop the last few shortest samples, so that the retained
+    total frames (after padding) do not exceed the given ``allowed_max_frames``.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      allowed_max_frames:
+        The allowed max number of frames in a batch.
+    """
+    features = batch["inputs"]
+    supervisions = batch["supervisions"]
+
+    N, T, _ = features.size()
+    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
+    keep_num_utt = allowed_max_frames // T
+
+    if keep_num_utt >= N:
+        return batch
+
+    # Note: we assume the batch is sorted by length in descending order.
+    logging.info(
+        f"Filtering uneven-sized batch, original batch size is {N}, "
+        f"retained batch size is {keep_num_utt}."
+    )
+    batch["inputs"] = features[:keep_num_utt]
+    for k, v in supervisions.items():
+        assert len(v) == N, (len(v), N)
+        batch["supervisions"][k] = v[:keep_num_utt]
+
+    return batch
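
Reviewer note (not part of the patch): below is a minimal, self-contained sanity
check for the new `filter_uneven_sized_batch` helper. It builds a toy batch with
the same fields that `lhotse.dataset.K2SpeechRecognitionDataset` puts into a batch
("inputs" plus a "supervisions" dict) and exercises the frame-budget arithmetic
used in compute_loss above. The max_duration value and tensor shapes are
illustrative assumptions; frame_shift_ms and the excess ratio mirror the params
added in train.py. It assumes an icefall checkout with this patch applied is on
PYTHONPATH.

    # Toy sanity check for filter_uneven_sized_batch (illustrative only).
    import logging

    import torch

    from icefall.utils import filter_uneven_sized_batch

    logging.basicConfig(level=logging.INFO)

    # Hypothetical settings: a max_duration of 600 s; the 10 ms frame shift and
    # the 10% excess budget match the values added in train.py.
    max_duration = 600
    frame_shift_ms = 10.0
    allowed_excess_duration_ratio = 0.1

    max_frames = max_duration * 1000 // frame_shift_ms  # 60000 frames
    allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))  # 66000

    # A toy batch of 4 utterances, already sorted by length in descending order.
    # Padding all of them to T = 20000 frames needs 4 * 20000 = 80000 frames,
    # which exceeds the 66000-frame budget, so the shortest sample is dropped.
    num_frames = torch.tensor([20000, 18000, 15000, 9000])
    N, T = len(num_frames), int(num_frames.max())
    batch = {
        "inputs": torch.zeros(N, T, 80),  # (N, T, C) padded features
        "supervisions": {
            "num_frames": num_frames,
            "text": ["a", "b", "c", "d"],
        },
    }

    batch = filter_uneven_sized_batch(batch, allowed_max_frames)
    assert batch["inputs"].size(0) == 3  # 66000 // 20000 == 3

The retained batch keeps the 3 longest utterances; for a batch whose padded total
is already within budget, the helper returns it unchanged.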
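
The one-line export.py change is easy to miss in review: without the conversion,
TorchScript export of the Scaled* modules that the stateless4 recipe uses during
training can fail or be unnecessarily slow. The sketch below shows only the
ordering the patch enforces; `export_for_cpp` is a hypothetical condensation of
the relevant part of main() in export.py (the real script also handles argument
parsing and checkpoint averaging), and it assumes it is run from the recipe
directory so that the symlinked scaling_converter module is importable.

    import torch

    from scaling_converter import convert_scaled_to_non_scaled


    def export_for_cpp(model: torch.nn.Module, filename: str = "cpu_jit.pt") -> None:
        model.eval()
        # Rewrite the Scaled* training-time modules into their plain torch.nn
        # equivalents so that torch.jit.script() can compile the model.
        convert_scaled_to_non_scaled(model, inplace=True)
        scripted = torch.jit.script(model)
        scripted.save(filename)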