From c74cec59e9f6d00e3a5838b4f8d4ace7e2303ad4 Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Wed, 17 Aug 2022 17:18:15 +0800
Subject: [PATCH] propagate changes from #525 to other librispeech recipes (#531)

* propagate changes from #525 to other librispeech recipes

* refactor display_and_save_batch to utils

* fixed typo

* reformat code style
---
 .../ASR/pruned_transducer_stateless/model.py  | 10 ++-
 .../ASR/pruned_transducer_stateless/train.py  | 34 +++++++++
 .../ASR/pruned_transducer_stateless2/train.py | 72 ++++++++++---------
 .../ASR/pruned_transducer_stateless3/model.py | 10 ++-
 .../ASR/pruned_transducer_stateless3/train.py | 40 ++++++++++-
 .../ASR/pruned_transducer_stateless4/train.py | 40 ++++++++++-
 .../ASR/pruned_transducer_stateless5/train.py | 42 +++--------
 .../ASR/pruned_transducer_stateless6/model.py | 10 ++-
 .../ASR/pruned_transducer_stateless6/train.py | 39 +++++++++-
 icefall/utils.py                              | 32 +++++++++
 10 files changed, 253 insertions(+), 76 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
index 2f019bcdb..e2c9eb789 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
@@ -66,6 +66,7 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        reduction: str = "sum",
     ) -> torch.Tensor:
         """
         Args:
@@ -86,6 +87,10 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          reduction:
+            "sum" to sum the losses over all utterances in the batch.
+            "none" to return the loss in a 1-D tensor for each utterance
+            in the batch.
         Returns:
           Return the transducer loss.
@@ -95,6 +100,7 @@ class Transducer(nn.Module):
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
         """
+        assert reduction in ("sum", "none"), reduction
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
@@ -136,7 +142,7 @@ class Transducer(nn.Module):
             lm_only_scale=lm_scale,
             am_only_scale=am_scale,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
             return_grad=True,
         )

@@ -163,7 +169,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
         )

         return (simple_loss, pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 33b23038c..c2e0f1f98 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -78,6 +78,7 @@ from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    display_and_save_batch,
     measure_gradient_norms,
     measure_weight_norms,
     optim_step_and_measure_param_change,
@@ -544,7 +545,36 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            reduction="none",
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
+
         loss = params.simple_loss_scale * simple_loss + pruned_loss

     assert loss.requires_grad == is_training
@@ -552,6 +582,10 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 4d290e39f..c801bd2bd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -88,7 +88,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -600,7 +606,35 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            reduction="none",
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
@@ -620,6 +654,10 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
@@ -993,38 +1031,6 @@ def run(rank, world_size, args):
         cleanup_dist()


-def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
-) -> None:
-    """Display the batch statistics and save the batch into disk.
-
-    Args:
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      params:
-        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
-    """
-    from lhotse.utils import uuid4
-
-    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-    logging.info(f"Saving batch to {filename}")
-    torch.save(batch, filename)
-
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
-
-    logging.info(f"features shape: {features.shape}")
-
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
-    logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
index 5894361fc..ece340534 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -105,6 +105,7 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        reduction: str = "sum",
     ) -> torch.Tensor:
         """
         Args:
@@ -131,6 +132,10 @@ class Transducer(nn.Module):
          warmup:
            A value warmup >= 0 that determines which modules are active, values
            warmup > 1 "are fully warmed up" and all modules will be active.
+          reduction:
+            "sum" to sum the losses over all utterances in the batch.
+            "none" to return the loss in a 1-D tensor for each utterance
+            in the batch.
         Returns:
           Return the transducer loss.
@@ -140,6 +145,7 @@ class Transducer(nn.Module):
              (1-lm_scale-am_scale) * combined_probs
         """
+        assert reduction in ("sum", "none"), reduction
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
@@ -196,7 +202,7 @@ class Transducer(nn.Module):
             lm_only_scale=lm_scale,
             am_only_scale=am_scale,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
             return_grad=True,
         )

@@ -229,7 +235,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
         )

         return (simple_loss, pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 914b9b5eb..be12e69ce 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -84,7 +84,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -637,7 +643,35 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            reduction="none",
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
@@ -657,6 +691,10 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 325b01323..2ba28acd4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -93,7 +93,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -630,7 +636,35 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            reduction="none",
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
@@ -650,6 +684,10 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index b7ef288c6..cee7d2bff 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -81,7 +81,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -670,7 +676,7 @@ def compute_loss(
             simple_loss = simple_loss[simple_loss_is_finite]
             pruned_loss = pruned_loss[pruned_loss_is_finite]

-            # If the batch contains more than 10 utterance AND
+            # If the batch contains more than 10 utterances AND
             # if either all simple_loss or pruned_loss is inf or nan,
             # we stop the training process by raising an exception
             if feature.size(0) >= 10:
@@ -1108,38 +1114,6 @@ def run(rank, world_size, args):
         cleanup_dist()


-def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
-) -> None:
-    """Display the batch statistics and save the batch into disk.
-
-    Args:
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      params:
-        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
-    """
-    from lhotse.utils import uuid4
-
-    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
-    logging.info(f"Saving batch to {filename}")
-    torch.save(batch, filename)
-
-    supervisions = batch["supervisions"]
-    features = batch["inputs"]
-
-    logging.info(f"features shape: {features.shape}")
-
-    y = sp.encode(supervisions["text"], out_type=int)
-    num_tokens = sum(len(i) for i in y)
-    logging.info(f"num tokens: {num_tokens}")
-
-
 def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
index 1ed5636c8..9de0769d9 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py
@@ -89,6 +89,7 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        reduction: str = "sum",
         codebook_indexes: torch.Tensor = None,
     ) -> torch.Tensor:
         """
@@ -113,6 +114,10 @@ class Transducer(nn.Module):
          warmup:
            A value warmup >= 0 that determines which modules are active, values
            warmup > 1 "are fully warmed up" and all modules will be active.
+          reduction:
+            "sum" to sum the losses over all utterances in the batch.
+            "none" to return the loss in a 1-D tensor for each utterance
+            in the batch.
          codebook_indexes:
            codebook_indexes extracted from a teacher model.
         Returns:
@@ -124,6 +129,7 @@ class Transducer(nn.Module):
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
         """
+        assert reduction in ("sum", "none"), reduction
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
@@ -184,7 +190,7 @@ class Transducer(nn.Module):
             lm_only_scale=lm_scale,
             am_only_scale=am_scale,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
             return_grad=True,
         )

@@ -217,7 +223,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
-            reduction="sum",
+            reduction=reduction,
         )

         return (simple_loss, pruned_loss, codebook_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
index a4595211c..294fd4c52 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py
@@ -93,7 +93,13 @@ from icefall.checkpoint import (
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    MetricsTracker,
+    display_and_save_batch,
+    setup_logger,
+    str2bool,
+)

 LRSchedulerType = Union[
     torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@@ -631,8 +637,35 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            reduction="none",
             codebook_indexes=codebook_indexes,
         )
+        simple_loss_is_finite = torch.isfinite(simple_loss)
+        pruned_loss_is_finite = torch.isfinite(pruned_loss)
+        is_finite = simple_loss_is_finite & pruned_loss_is_finite
+        if not torch.all(is_finite):
+            logging.info(
+                "Not all losses are finite!\n"
+                f"simple_loss: {simple_loss}\n"
+                f"pruned_loss: {pruned_loss}"
+            )
+            display_and_save_batch(batch, params=params, sp=sp)
+            simple_loss = simple_loss[simple_loss_is_finite]
+            pruned_loss = pruned_loss[pruned_loss_is_finite]
+            # If the batch contains more than 10 utterances AND
+            # if either all simple_loss or pruned_loss is inf or nan,
+            # we stop the training process by raising an exception
+            if feature.size(0) >= 10:
+                if torch.all(~simple_loss_is_finite) or torch.all(
+                    ~pruned_loss_is_finite
+                ):
+                    raise ValueError(
+                        "There are too many utterances in this batch "
+                        "leading to inf or nan losses."
+                    )
+
+        simple_loss = simple_loss.sum()
+        pruned_loss = pruned_loss.sum()
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
@@ -654,6 +687,10 @@ def compute_loss(

     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+        # (2) If some utterances in the batch lead to inf/nan loss, they
+        #     are filtered out.
         info["frames"] = (
             (feature_lens // params.subsampling_factor).sum().item()
         )
diff --git a/icefall/utils.py b/icefall/utils.py
index 2b089c8d0..ad079222e 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -944,3 +944,35 @@ def tokenize_by_bpe_model(
         txt_with_bpe = "/".join(tokens)

     return txt_with_bpe
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
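
Note (not part of the patch): the pattern this commit copies into every recipe's compute_loss() is easiest to see in isolation. The following is a minimal, self-contained Python sketch with made-up per-utterance loss tensors; batch_size stands in for the real feature.size(0), and the print calls stand in for logging and display_and_save_batch().

import torch

# Per-utterance losses, as returned when the model is called with
# reduction="none": shape (N,), one entry per utterance in the batch.
simple_loss = torch.tensor([1.2, float("inf"), 0.9, 1.1])
pruned_loss = torch.tensor([0.3, 0.4, float("nan"), 0.2])

simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite

if not torch.all(is_finite):
    # The recipes log the losses and call display_and_save_batch() here
    # so the offending batch can be inspected later.
    print("Not all losses are finite!")
    # Keep only the utterances whose losses are finite.
    simple_loss = simple_loss[simple_loss_is_finite]
    pruned_loss = pruned_loss[pruned_loss_is_finite]

    # Safety check mirrored from the patch: in a reasonably large batch
    # (>= 10 utterances), all-inf/nan losses point at a real problem, so
    # training is aborted instead of silently dropping the whole batch.
    batch_size = 4  # stands in for feature.size(0)
    if batch_size >= 10:
        if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite):
            raise ValueError("Too many utterances lead to inf or nan losses.")

# Reduce to scalars, as the rest of the training loop expects.
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
print(simple_loss.item(), pruned_loss.item())

This filtering is only possible because the losses now come back per utterance; with the previous reduction="sum" the model returned scalars, so a single bad utterance poisoned the whole batch.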
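
Also not part of the patch: when display_and_save_batch() (now shared in icefall/utils.py) fires, it writes the offending batch to {exp_dir}/batch-<uuid>.pt via torch.save. A saved file can be reloaded for offline debugging; the sketch below assumes hypothetical file and BPE-model paths, which you would replace with the ones from your own run.

import torch
import sentencepiece as spm

# Hypothetical paths; substitute the ones from your own experiment.
batch_file = "pruned_transducer_stateless2/exp/batch-0a1b2c3d.pt"
bpe_model = "data/lang_bpe_500/bpe.model"

batch = torch.load(batch_file, map_location="cpu")
features = batch["inputs"]            # (N, T, C) filter-bank features
supervisions = batch["supervisions"]  # contains "text" and other supervision fields

print("features shape:", features.shape)

sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
y = sp.encode(supervisions["text"], out_type=int)
print("num tokens:", sum(len(i) for i in y))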