diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index c9d9c4aa8..fa809b768 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index d08908238..60f014c48 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -872,7 +872,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 62e67530d..7c23041ca 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -1045,7 +1045,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index cbb7db086..11671db92 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py index c30f6f960..057af297f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train2.py @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 4e52f9573..3858bafd7 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 74bf68ccb..8c7448d4c 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -730,7 +730,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -919,7 +918,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 47015cbe7..a354f761e 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -908,7 +908,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index e57b5c859..30154291d 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -635,7 +635,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -800,7 +799,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 45d777922..8f09f1aa5 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -999,7 +999,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 8c8d9593b..9b67141c0 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -988,7 +988,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 4bd5b83a2..4aedeffe4 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -1019,7 +1019,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 18cb75c37..73fcd67aa 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py index bc4bcf253..4c866ddd8 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1075,7 +1075,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 99fe64793..828106f41 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -557,7 +557,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index c5a05d349..ca21bd6bf 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 6bb37b017..23ddb6bec 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -953,7 +953,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py index 36067510c..420dc1065 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -955,7 +955,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index da1648d06..5a36ff2a9 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -43,6 +43,7 @@ from pathlib import Path from tqdm.auto import tqdm + # This function is copied from lhotse def tqdm_urlretrieve_hook(t): """Wraps tqdm instance. diff --git a/egs/librispeech/ASR/long_file_recog/beam_search.py b/egs/librispeech/ASR/long_file_recog/beam_search.py index f8c31861c..b65e9d40a 100644 --- a/egs/librispeech/ASR/long_file_recog/beam_search.py +++ b/egs/librispeech/ASR/long_file_recog/beam_search.py @@ -236,7 +236,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -507,7 +507,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/long_file_recog/merge_chunks.py b/egs/librispeech/ASR/long_file_recog/merge_chunks.py index d38d9c86a..9e31e00d5 100755 --- a/egs/librispeech/ASR/long_file_recog/merge_chunks.py +++ b/egs/librispeech/ASR/long_file_recog/merge_chunks.py @@ -162,7 +162,6 @@ def merge_chunks( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: - for cut in cuts_chunk: cur_rec_id = cut.recording.id if len(cut_list) == 0: diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py index 96c83f859..466253446 100755 --- a/egs/librispeech/ASR/long_file_recog/recognize.py +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -264,6 +264,7 @@ def decode_dataset( - timestamps of reference transcript - timestamps of predicted result """ + # Background worker to add alignemnt and save cuts to disk. def _save_worker( cuts: List[Cut], diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..9f287ce70 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 77e06d3b7..a4899f7bd 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -811,7 +811,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 3298568a3..7fcd242fc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -719,7 +719,7 @@ def greedy_search_batch( encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1019,7 +1019,7 @@ def modified_beam_search( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] @@ -2608,7 +2608,6 @@ def modified_beam_search_LODR( context_score = 0 new_context_state = None if context_graph is None else hyp.context_state if new_token not in (blank_id, unk_id): - if context_graph is not None: ( context_score, @@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] # get batch @@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion( new_token = topk_token_indexes[k] new_timestamp = hyp.timestamp[:] if new_token not in (blank_id, unk_id): - ys.append(new_token) new_timestamp.append(t) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..f54bc2709 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -66,7 +66,6 @@ class Eve(Optimizer): weight_decay=1e-3, target_rms=0.1, ): - if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 963ebdc2d..91d64c1df 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM): return with torch.cuda.device_of(first_fw): - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is # an inplace operation on self._flat_weights with torch.no_grad(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 3b5a635e4..66dc5f991 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 14ff86f23..3bca7db2c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -56,7 +56,6 @@ class CodebookIndexExtractor: """ def __init__(self, params: AttributeDict): - self.params = params params.subsets = ["clean-100"] if self.params.full_libri: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py index 76cd56bbb..bfb5fe609 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -111,7 +111,7 @@ def batch_force_alignment( offset = 0 finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): + for t, batch_size in enumerate(batch_size_list): start = offset end = offset + batch_size current_encoder_out = encoder_out.data[start:end] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 3ee2b7d65..4e261dbc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -1132,7 +1132,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index aa3cef338..8ab3589da 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): parameters_names=None, show_dominant_parameters=True, ): - assert parameters_names is not None, ( "Please prepare parameters_names," "which is a List[List[str]]. Each List[str] is for a group" @@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 2b4d51089..fac3706d2 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -1028,7 +1028,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b387968a9..d8fa08372 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1052,7 +1052,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index 23fb6f497..25a1aa674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -1042,7 +1042,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 99090b2c1..2d915ff87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -1029,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py index 9be629149..aa6c0668a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py @@ -1030,7 +1030,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index b494253d6..565dc7a16 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -1141,7 +1141,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index bee414292..3f271c5b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -1154,7 +1154,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index be6fabf35..0b982f4bf 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -230,7 +230,9 @@ class Conformer(Transformer): x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask ) # (T, B, F) else: - x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) if self.normalize_before: x = self.after_norm(x) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index bb55ed6bb..14d7274c2 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -543,7 +543,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 0aa1587ba..90245ed46 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -463,7 +463,6 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}" ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index f2a09346c..9ac6b7d03 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -513,7 +513,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index a6f2bd08c..92134116c 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -517,7 +517,6 @@ def train_one_epoch( ) if batch_idx % params.log_interval == 0: - if tb_writer is not None: loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index e8db988f6..e77e54118 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -61,10 +61,15 @@ class Decoder(nn.Module): ) # the balancers are to avoid any drift in the magnitude of the # embeddings, which would interact badly with parameter averaging. - self.balancer = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) self.blank_id = blank_id @@ -81,10 +86,15 @@ class Decoder(nn.Module): groups=decoder_dim // 4, # group size == 4 bias=False, ) - self.balancer2 = Balancer(decoder_dim, channel_dim=-1, - min_positive=0.0, max_positive=1.0, - min_abs=0.5, max_abs=1.0, - prob=0.05) + self.balancer2 = Balancer( + decoder_dim, + channel_dim=-1, + min_positive=0.0, + max_positive=1.0, + min_abs=0.5, + max_abs=1.0, + prob=0.05, + ) def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ @@ -107,9 +117,7 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad( - embedding_out, pad=(self.context_size - 1, 0) - ) + embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index f03cc930e..dfb0a0057 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -52,12 +52,13 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape) + assert encoder_out.ndim == decoder_out.ndim, ( + encoder_out.shape, + decoder_out.shape, + ) if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj( - decoder_out - ) + logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/zipformer/onnx_decode.py b/egs/librispeech/ASR/zipformer/onnx_decode.py index 2aca36ca9..356c2a830 100755 --- a/egs/librispeech/ASR/zipformer/onnx_decode.py +++ b/egs/librispeech/ASR/zipformer/onnx_decode.py @@ -303,7 +303,9 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): start_time = time.time() - results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table) + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) end_time = time.time() elapsed_seconds = end_time - start_time rtf = elapsed_seconds / total_duration diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index abfb2092c..c9b76526c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -299,8 +298,8 @@ class ScaledAdam(BatchedOptimizer): # the input is groups of parameter or named parameter. for cur_group in iterable_or_groups: assert "named_params" in cur_group - name_list = [ x[0] for x in cur_group["named_params"] ] - p_list = [ x[1] for x in cur_group["named_params"] ] + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] del cur_group["named_params"] cur_group["params"] = p_list param_groups.append(cur_group) @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -513,7 +510,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -529,7 +526,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -667,8 +663,7 @@ class ScaledAdam(BatchedOptimizer): # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, - (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) delta = state["delta"] # the factor of (1-beta1) relates to momentum. @@ -879,7 +874,8 @@ class Eden(LRScheduler): warmup_factor = ( 1.0 if self.batch >= self.warmup_batches - else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) # else 0.5 + 0.5 * (self.batch / self.warmup_batches) ) @@ -1111,7 +1107,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/profile.py index b460b5338..57f44a90a 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/profile.py @@ -100,17 +100,13 @@ class Model(nn.Module): self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward( - self, feature: Tensor, feature_lens: Tensor - ) -> Tuple[Tensor, Tensor]: + def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder( - x, x_lens, src_key_padding_mask - ) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -168,9 +164,7 @@ def main(): if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 7c98ef045..c0f1e3087 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -25,6 +25,7 @@ import math import torch.nn as nn from torch import Tensor + def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: max_value = torch.max(x, y) diff = torch.abs(x - y) @@ -55,28 +56,34 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: # for torch.jit.trace() return torch.logaddexp(x, y) + class PiecewiseLinear(object): """ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] respectively. """ + def __init__(self, *args): assert len(args) >= 1, len(args) if len(args) == 1 and isinstance(args[0], PiecewiseLinear): self.pairs = list(args[0].pairs) else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: assert isinstance(x, (float, int)), type(x) assert isinstance(y, (float, int)), type(y) for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], (i, self.pairs[i], self.pairs[i + 1]) + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) def __str__(self): # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" def __call__(self, x): if x <= self.pairs[0][0]: @@ -93,37 +100,36 @@ class PiecewiseLinear(object): assert False def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) def __add__(self, x): if isinstance(x, (float, int)): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) s, x = self.get_common_basis(x) return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) def max(self, x): if isinstance(x, (float, int)): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def min(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): """ Returns (self_mod, p_mod) which are equivalent piecewise linear functions to self and p, but with the same x values. @@ -135,28 +141,30 @@ class PiecewiseLinear(object): assert isinstance(p, PiecewiseLinear), type(p) # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, _ in self.pairs ] + [ x for x, _ in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) # `pos`, between 0 and 1, gives the relative x position, # with 0 being x_vals[i] and 1 being x_vals[i+1]. pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) class ScheduledFloat(torch.nn.Module): @@ -176,9 +184,8 @@ class ScheduledFloat(torch.nn.Module): `default` is used when self.batch_count is not set or not in training mode or in torch.jit scripting mode. """ - def __init__(self, - *args, - default: float = 0.0): + + def __init__(self, *args, default: float = 0.0): super().__init__() # self.batch_count and self.name will be written to in the training loop. self.batch_count = None @@ -187,47 +194,55 @@ class ScheduledFloat(torch.nn.Module): self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) def __float__(self): batch_count = self.batch_count - if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): return float(self.default) else: ans = self.schedule(self.batch_count) if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) return ans def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) + return ScheduledFloat(self.schedule + x, default=self.default) else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) + return ScheduledFloat(self.schedule.max(x), default=self.default) else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) FloatLike = Union[float, ScheduledFloat] -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -242,6 +257,7 @@ class CutoffEstimator: p is the proportion of items that should be above the cutoff. """ + def __init__(self, p: float): self.p = p # total count of items @@ -255,7 +271,7 @@ class CutoffEstimator: """ Returns true if x is above the cutoff. """ - ans = (x > self.cutoff) + ans = x > self.cutoff self.count += 1 if ans: self.count_above += 1 @@ -263,7 +279,7 @@ class CutoffEstimator: delta_p = cur_p - self.p if (delta_p > 0) == ans: q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) + self.cutoff = x * q + self.cutoff * (1 - q) return ans @@ -272,6 +288,7 @@ class SoftmaxFunction(torch.autograd.Function): Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -287,7 +304,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -306,17 +323,16 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x @staticmethod @@ -328,15 +344,20 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -348,8 +369,14 @@ class BiasNormFunction(torch.autograd.Function): # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). @staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim @@ -357,10 +384,16 @@ class BiasNormFunction(torch.autograd.Function): ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) return ans @staticmethod @@ -376,7 +409,9 @@ class BiasNormFunction(torch.autograd.Function): log_scale.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), log_scale.grad, None, None @@ -412,14 +447,15 @@ class BiasNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ + def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels @@ -442,23 +478,24 @@ class BiasNorm(torch.nn.Module): bias = self.bias for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() return x * scales - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -477,15 +514,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -504,15 +537,11 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs) -> nn.Conv2d: +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: """ Behaves like a constructor of a modified version of nn.Conv2d that gives an easy way to set the default initial parameter scale. @@ -532,9 +561,7 @@ def ScaledConv2d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -562,29 +589,36 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): Another option, if you want to do something like this, is to re-initialize the parameters. """ - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): super().__init__() assert kernel_size % 2 == 1 half_kernel_size = (kernel_size + 1) // 2 # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, @@ -596,17 +630,15 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): self.causal_conv.weight[:] *= initial_scale self.chunkwise_conv.weight[:] *= initial_scale if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ (batch_size, num_channels, seq_len) = x.shape @@ -622,28 +654,32 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x = torch.nn.functional.pad(x, (left_pad, right_pad)) - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) assert x_causal.shape == (batch_size, num_channels, seq_len) x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) x_chunk = self.chunkwise_conv(x_chunk) # does not change shape chunk_scale = self._get_chunk_scale(chunk_size) x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] return x_chunk + x_causal def _get_chunk_scale(self, chunk_size: int): """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" + scale the output of self.chunkwise_conv.""" left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] if chunk_size < self.kernel_size: @@ -652,9 +688,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): else: t = chunk_size - self.kernel_size channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) left_edge = torch.cat((left_edge, pad), dim=-1) right_edge = torch.cat((pad, right_edge), dim=-1) return 1.0 + (left_edge + right_edge) @@ -698,14 +734,14 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): class BalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim @@ -715,10 +751,8 @@ class BalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config try: @@ -727,8 +761,8 @@ class BalancerFunction(torch.autograd.Function): x = x.to(torch.float32) x = x.detach() x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() rms = uncentered_var.clamp(min=1.0e-20).sqrt() @@ -742,11 +776,16 @@ class BalancerFunction(torch.autograd.Function): rms_clamped = rms.clamp(min=min_rms, max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss) + loss = m_loss + r_loss loss.backward(gradient=torch.ones_like(loss)) loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) loss_grad = loss_grad * (grad_scale / loss_grad_rms) @@ -757,7 +796,9 @@ class BalancerFunction(torch.autograd.Function): x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) x_grad = x_grad_mod.to(x_grad.dtype) except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -793,16 +834,17 @@ class Balancer(torch.nn.Module): on each forward(). This is done randomly to prevent all layers from doing it at the same time. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, ): super().__init__() @@ -823,8 +865,11 @@ class Balancer(torch.nn.Module): self.grad_scale = grad_scale def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): return _no_op(x) prob = float(self.prob) @@ -842,7 +887,7 @@ class Balancer(torch.nn.Module): eps = 1.0e-10 # eps is to prevent crashes if x is exactly 0 or 1. # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 def _approx_inverse_erf(x): # 1 / (sqrt(pi) * ln(2)), @@ -853,6 +898,7 @@ class Balancer(torch.nn.Module): # and math.erf(0.0407316414078772) = 0.045935330944660666, # which is pretty close to 0.05. return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error # function returns x = -1 + (2 * x) @@ -873,8 +919,9 @@ class Balancer(torch.nn.Module): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -910,13 +957,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -946,25 +992,22 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: ctx.save_for_backward(x) ctx.module = module return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors w = ctx.module try: @@ -976,8 +1019,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, w.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) if metric < float(w.whitening_limit): w.prob = w.min_prob @@ -986,22 +1031,27 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -1033,10 +1083,9 @@ class Whiten(nn.Module): (self.min_prob, self.max_prob) = prob assert 0 < self.min_prob <= self.max_prob <= 1 self.prob = self.max_prob - self.name = None # will be set in training loop + self.name = None # will be set in training loop - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -1071,9 +1120,11 @@ class WithLoss(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) def with_loss(x, y, name): @@ -1118,20 +1169,21 @@ class LimitParamValue(torch.autograd.Function): @staticmethod def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) return x_grad, None, None -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. @@ -1187,7 +1239,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 @@ -1197,7 +1249,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1210,12 +1264,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) + d = d * ((ceil - floor) / 255.0) + floor return y_grad * d @@ -1239,9 +1293,7 @@ class Dropout2(nn.Module): self.p = p def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) class MulForDropout3(torch.autograd.Function): @@ -1259,7 +1311,7 @@ class MulForDropout3(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, ans_grad): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None @@ -1286,7 +1338,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1308,13 +1360,15 @@ class SwooshLFunction(torch.autograd.Function): if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1328,20 +1382,19 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. coeff = -0.08 floor = coeff ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 @@ -1351,19 +1404,19 @@ class SwooshL(torch.nn.Module): return k2.swoosh_l(x) # return SwooshLFunction.apply(x) + class SwooshLOnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 class SwooshRFunction(torch.autograd.Function): """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - derivatives are between -0.08 and 0.92. + derivatives are between -0.08 and 0.92. """ @staticmethod @@ -1379,17 +1432,19 @@ class SwooshRFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = -0.08 ceil = 0.925 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1403,33 +1458,32 @@ class SwooshRFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.08 ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" if torch.jit.is_scripting() or torch.jit.is_tracing(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: return k2.swoosh_r_forward(x) else: return k2.swoosh_r(x) # return SwooshRFunction.apply(x) + class SwooshROnnx(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 # simple version of SwooshL that does not redefine the backprop, used in @@ -1437,7 +1491,7 @@ class SwooshROnnx(torch.nn.Module): def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1446,28 +1500,30 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687 class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): if dropout_p != 0.0: dropout_shape = list(x.shape) if dropout_shared_dim is not None: dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) else: dropout_mask = None @@ -1476,8 +1532,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): ctx.activation = activation forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1495,8 +1551,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. @@ -1511,8 +1567,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): in_channels = y.shape[-1] g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv @@ -1525,71 +1580,76 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): class ActivationDropoutAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). """ - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): super().__init__() # create a temporary module of nn.Linear that we'll steal the # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) self.weight = l.weight # register_parameter properly handles making it a parameter when l.bias # is None. I think there is some reason for doing it this way rather # than just setting it to None but I don't know what it is, maybe # something to do with exporting the module.. - self.register_parameter('bias', l.bias) + self.register_parameter("bias", l.bias) self.activation = activation self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting() or torch.jit.is_tracing(): - if self.activation == 'SwooshL': + if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) else: assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) + return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: @@ -1612,10 +1672,9 @@ def _test_whiten(): x.requires_grad = True - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1656,9 +1715,7 @@ def _test_balancer_sign(): def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = Balancer( @@ -1685,7 +1742,7 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) # for self-test. @@ -1699,7 +1756,7 @@ def _test_swooshl_deriv(): x.requires_grad = True m = SwooshL() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1713,7 +1770,7 @@ def _test_swooshr_deriv(): x.requires_grad = True m = SwooshR() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1727,24 +1784,24 @@ def _test_softmax(): b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) + p = PiecewiseLinear((0, 10.0)) for x in [-100, 0, 100]: assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: print("x, y = ", x, y) assert p(x) == y, (x, p(x), y) q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] pq = p.max(q) for x in x_vals: y1 = max(p(x), q(x)) @@ -1757,7 +1814,7 @@ def _test_piecewise_linear(): assert abs(y1 - y2) < 0.001 pq = p + q for x in x_vals: - y1 = p(x) + q(x) + y1 = p(x) + q(x) y2 = pq(x) assert abs(y1 - y2) < 0.001 @@ -1772,15 +1829,22 @@ def _test_activation_dropout_and_linear(): # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) with torch.no_grad(): m2.weight[:] = m1[2].weight if bias: @@ -1790,9 +1854,9 @@ def _test_activation_dropout_and_linear(): x1.requires_grad = True # TEMP. - assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) x2 = x1.clone().detach() x2.requires_grad = True @@ -1805,21 +1869,24 @@ def _test_activation_dropout_and_linear(): y2 = m2(x2) y2.backward(gradient=y_grad) - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) def isclose(a, b): # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad) diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 44ff392a3..904caf8af 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat( - [state_list[i][-1] for i in range(batch_size)], dim=0 - ) + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states @@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk( - chunks=batch_size, dim=1 - ) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk( - chunks=batch_size, dim=0 - ) + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -380,11 +374,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -404,9 +394,7 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -494,9 +482,7 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search( - model=model, encoder_out=encoder_out, streams=decode_streams - ) + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -517,9 +503,7 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) @@ -577,9 +561,7 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = get_init_states( - model=model, batch_size=1, device=device - ) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -649,9 +631,7 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} @@ -684,8 +664,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -718,9 +697,7 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." @@ -760,9 +737,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -789,9 +766,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 6532ddccb..d16d87bac 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -107,9 +107,7 @@ class ConvNeXt(nn.Module): if layerdrop_rate != 0.0: batch_size = x.shape[0] mask = ( - torch.rand( - (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device - ) + torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate ) else: @@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module): # many copies of this extra gradient term. self.out_whiten = Whiten( num_groups=1, - whitening_limit=ScheduledFloat( - (0.0, 4.0), (20000.0, 8.0), default=4.0 - ), + whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), prob=(0.025, 0.25), grad_scale=0.02, ) @@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module): with warnings.catch_warnings(): warnings.simplefilter("ignore") x_lens = (x_lens - 7) // 2 - assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max()) + assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max()) return x, x_lens @@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module): left_pad = self.convnext.padding[0] freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros( - batch_size, channels, left_pad, freq - ).to(device) + cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( + device + ) return cached_embed_left_pad diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bc3e9c1ba..7009f3346 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: def get_model(params: AttributeDict) -> nn.Module: - assert ( - params.use_transducer or params.use_ctc - ), (f"At least one of them should be True, " + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}") + f"params.use_ctc={params.use_ctc}" + ) encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) @@ -808,17 +808,16 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss += ( - simple_loss_scale * simple_loss - + pruned_loss_scale * pruned_loss - ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss @@ -1166,7 +1165,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1b3ea3e0..4b50acdde 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -981,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index a68702776..48468cfbd 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -746,7 +746,6 @@ def train_one_epoch( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(train_dl): - if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -966,7 +965,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1019,7 +1018,6 @@ def run(rank, world_size, args): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom( # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 4f2d728be..c1bbd2ee8 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index 417515968..d03970265 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -915,7 +915,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index d80e0147c..aee3972cd 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -69,7 +69,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer -from icefall import diagnostics, byte_encode, tokenize_by_CJK_char +from icefall import byte_encode, diagnostics, tokenize_by_CJK_char from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -1018,7 +1018,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index 42e4c010a..fc3e3b2d9 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -905,7 +905,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 9271c8438..33d03908c 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index e703100a9..82bc882bd 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -886,7 +886,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 48b347b64..49977e01b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -851,7 +851,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 8e1b12dba..931e699d9 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -985,7 +985,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 83dbfa22f..b1557dedb 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index 5b5ac17be..a6fa46b17 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index f8dd7b287..8c53972fd 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -993,7 +993,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/icefall/__init__.py b/icefall/__init__.py index 05e2b408c..b1e4313e9 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,12 +1,6 @@ # isort:skip_file -from . import ( - checkpoint, - decode, - dist, - env, - utils -) +from . import checkpoint, decode, dist, env, utils from .byte_utils import ( byte_decode, diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 01836df04..0b7c42c0b 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -227,7 +227,6 @@ class ContextGraph: filename: Optional[str] = "", symbol_table: Optional[Dict[int, str]] = None, ) -> "Digraph": # noqa - """Visualize a ContextGraph via graphviz. Render ContextGraph as an image via graphviz, and return the Digraph object; diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 98870684e..700dc1500 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -23,6 +23,7 @@ from typing import Optional, Tuple, List import torch from torch import Tensor, nn + class TensorDiagnosticOptions(object): """Options object for tensor diagnostics: @@ -77,11 +78,11 @@ def get_tensor_stats( elif stats_type == "abs": x = x.abs() elif stats_type == "rms": - x = x ** 2 + x = x**2 elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type in [ "value", "max", "min" ] + assert stats_type in ["value", "max", "min"] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: @@ -121,10 +122,10 @@ class TensorDiagnostic(object): self.class_name = None # will assign in accumulate() self.stats = None # we'll later assign a list to self.stats. - # It's a list of dicts, indexed by dim (i.e. by the - # axis of the tensor). The dicts, in turn, are - # indexed by `stats-type` which are strings in - # ["abs", "max", "min", "positive", "value", "rms"]. + # It's a list of dicts, indexed by dim (i.e. by the + # axis of the tensor). The dicts, in turn, are + # indexed by `stats-type` which are strings in + # ["abs", "max", "min", "positive", "value", "rms"]. # scalar_stats contains some analysis of the activations and gradients, self.scalar_stats = None @@ -139,7 +140,6 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. - def accumulate(self, x, class_name: Optional[str] = None): """ Accumulate tensors. @@ -193,17 +193,12 @@ class TensorDiagnostic(object): done = True break if not done: - if ( - this_dim_stats[stats_type] != [] - and stats_type == "eigs" - ): + if this_dim_stats[stats_type] != [] and stats_type == "eigs": # >1 size encountered on this dim, e.g. it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append( - TensorAndCount(stats, count) - ) + this_dim_stats[stats_type].append(TensorAndCount(stats, count)) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -220,8 +215,11 @@ class TensorDiagnostic(object): for r, v in zip(rms_stats_list, value_stats_list): stddev_stats_list.append( # r.count and v.count should be the same, but we don't check this. - TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), - r.count)) + TensorAndCount( + r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20), + r.count, + ) + ) this_dim_stats["stddev"] = stddev_stats_list for stats_type, stats_list in this_dim_stats.items(): @@ -232,7 +230,6 @@ class TensorDiagnostic(object): assert stats_type == "eigs" continue - def get_count(count): return 1 if stats_type in ["max", "min"] else count @@ -250,22 +247,20 @@ class TensorDiagnostic(object): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print( - "Error getting eigenvalues, trying another method." - ) + print("Error getting eigenvalues, trying another method.") eigs, _ = torch.eig(stats) stats = eigs.norm(dim=1).sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - if stats_type in [ "rms", "stddev" ]: + if stats_type in ["rms", "stddev"]: # we stored the square; after aggregation we need to take sqrt. stats = stats.sqrt() # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = ( - len(stats_list) > 1 - ) or self.opts.dim_is_summarized(stats.numel()) + summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( + stats.numel() + ) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -282,15 +277,15 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in [ "value", "rms", "stddev", "eigs" ]: + if stats_type in ["value", "rms", "stddev", "eigs"]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() + norm = (stats**2).sum().sqrt().item() ans += f", norm={norm:.2g}" mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() + rms = (stats**2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. @@ -298,11 +293,11 @@ class TensorDiagnostic(object): sizes = [x.tensor.shape[0] for x in stats_list] size_str = ( - f"{sizes[0]}" - if len(sizes) == 1 - else f"{min(sizes)}..{max(sizes)}" + f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" + ) + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" ) - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) @@ -330,7 +325,6 @@ class ScalarDiagnostic(object): self.sum_gradsq = None self.sum_abs_grad = None - def accumulate_input(self, x: Tensor, class_name: Optional[str] = None): """ Called in forward pass. @@ -347,8 +341,10 @@ class ScalarDiagnostic(object): limit = 10 if len(self.saved_inputs) > limit: - print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. " - f" Will not accumulate scalar stats.") + print( + f"ERROR: forward pass called for this module over {limit} times with no backward pass. " + f" Will not accumulate scalar stats." + ) self.is_ok = False return self.saved_inputs.append(x) @@ -359,11 +355,15 @@ class ScalarDiagnostic(object): if self.is_forward_pass: self.is_forward_pass = False - last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + last_shape = ( + "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape + ) if len(self.saved_inputs) == 0 or grad.shape != last_shape: - print(f"ERROR: shape mismatch or no forward activation present when backward " - f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" - f", shape-of-last-saved-input={last_shape}") + print( + f"ERROR: shape mismatch or no forward activation present when backward " + f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}" + f", shape-of-last-saved-input={last_shape}" + ) self.is_ok = False return @@ -384,11 +384,19 @@ class ScalarDiagnostic(object): self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side) # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1] - self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device) - self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.counts = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.long, device=x.device + ) + self.sum_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # sum_gradsq is for getting error bars. - self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) - self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device) + self.sum_gradsq = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) + self.sum_abs_grad = torch.zeros( + 2 * num_ticks_per_side, dtype=torch.double, device=x.device + ) # this will round down. x = (x / self.tick_scale).to(torch.long) @@ -397,20 +405,21 @@ class ScalarDiagnostic(object): self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x)) self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double)) - self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double)) + self.sum_gradsq.index_add_( + dim=0, index=x, source=(grad * grad).to(torch.double) + ) self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double)) - def print_diagnostics(self): """Print diagnostics.""" if self.is_ok is False or self.counts is None: print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}") return - counts = self.counts.to('cpu') - sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32) - sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32) - sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32) + counts = self.counts.to("cpu") + sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32) + sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32) + sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32) counts_cumsum = counts.cumsum(dim=0) counts_tot = counts_cumsum[-1] @@ -433,19 +442,22 @@ class ScalarDiagnostic(object): bin_abs_grad = torch.zeros(num_bins) bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad) - avg_grad = (bin_grad / bin_counts) + avg_grad = bin_grad / bin_counts avg_grad_stddev = (bin_gradsq / bin_counts).sqrt() - bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + bin_boundary_counts = ( + torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin + ) bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts) # boundaries are the "x" values between the bins, e.g. corresponding to the # locations of percentiles of the distribution. num_ticks_per_side = counts.numel() // 2 bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale - bin_grad = bin_grad / (bin_counts + 1) - bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1) # consider this a standard deviation. + bin_conf_interval = bin_gradsq.sqrt() / ( + bin_counts + 1 + ) # consider this a standard deviation. # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense, # the gradients are. bin_abs_grad = bin_abs_grad / (bin_counts + 1) @@ -458,8 +470,9 @@ class ScalarDiagnostic(object): x = "[" + " ".join(x) + "]" return x - - maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" + maybe_class_name = ( + f" type={self.class_name}," if self.class_name is not None else "" + ) print( f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, " @@ -467,7 +480,6 @@ class ScalarDiagnostic(object): ) - class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -485,9 +497,8 @@ class ModelDiagnostic(object): self.opts = opts self.diagnostics = dict() - def __getitem__(self, name: str): - T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic + T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic if name not in self.diagnostics: self.diagnostics[name] = T(self.opts, name) return self.diagnostics[name] @@ -502,18 +513,19 @@ def get_class_name(module: nn.Module): ans = type(module).__name__ # we put the below in try blocks in case anyone is using a different version of these modules that # might have different member names. - if ans == 'Balancer' or ans == 'ActivationBalancer': + if ans == "Balancer" or ans == "ActivationBalancer": try: - ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]' + ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]" except: pass - elif ans == 'AbsValuePenalizer': + elif ans == "AbsValuePenalizer": try: - ans += f'[{module.limit}]' + ans += f"[{module.limit}]" except: pass return ans + def attach_diagnostics( model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None ) -> ModelDiagnostic: @@ -538,73 +550,85 @@ def attach_diagnostics( if name == "": name = "" - - # Setting model_diagnostic=ans and n=name below, instead of trying to # capture the variables, ensures that we use the current values. # (this matters for `name`, since the variable gets overwritten). # These closures don't really capture by value, only by # "the final value the variable got in the function" :-( - def forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.output"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, - class_name=get_class_name(_module)) + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.output[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) - def backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name - ): + def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] - if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad"].accumulate(_output, - class_name=get_class_name(_module)) + if isinstance(_output, Tensor) and _output.dtype in ( + torch.float32, + torch.float16, + torch.float64, + ): + _model_diagnostic[f"{_name}.grad"].accumulate( + _output, class_name=get_class_name(_module) + ) elif isinstance(_output, tuple): for i, o in enumerate(_output): - if o.dtype in ( torch.float32, torch.float16, torch.float64 ): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, - class_name=get_class_name(_module)) - + if o.dtype in (torch.float32, torch.float16, torch.float64): + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( + o, class_name=get_class_name(_module) + ) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) - if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]: + if type(module).__name__ in [ + "Sigmoid", + "Tanh", + "ReLU", + "TanSwish", + "Swish", + "DoubleSwish", + "Swoosh", + ]: # For these specific module types, accumulate some additional diagnostics # that can help us improve the activation function. These require a lot of memory, # to save the forward activations, so limit this to some select classes. # Note: this will not work correctly for all model types. def scalar_forward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_input, tuple): - _input, = _input + (_input,) = _input assert isinstance(_input, Tensor) - _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input, - class_name=get_class_name(_module)) + _model_diagnostic[f"{_name}.scalar"].accumulate_input( + _input, class_name=get_class_name(_module) + ) def scalar_backward_hook( - _module, _input, _output, _model_diagnostic=ans, _name=name + _module, _input, _output, _model_diagnostic=ans, _name=name ): if isinstance(_output, tuple): - _output, = _output + (_output,) = _output assert isinstance(_output, Tensor) _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output) module.register_forward_hook(scalar_forward_hook) module.register_backward_hook(scalar_backward_hook) - - for name, parameter in model.named_parameters(): def param_backward_hook( diff --git a/icefall/profiler.py b/icefall/profiler.py index dc76ebebc..49e138579 100644 --- a/icefall/profiler.py +++ b/icefall/profiler.py @@ -70,25 +70,17 @@ class FlopsProfiler(object): module_flop_count.append([]) if not hasattr(module, "__pre_hook_handle__"): - module.__pre_hook_handle__ = module.register_forward_pre_hook( - pre_hook - ) + module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) def post_hook(module, input, output): if module_flop_count: - module.__flops__ += sum( - [elem[1] for elem in module_flop_count[-1]] - ) + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) module_flop_count.pop() if not hasattr(module, "__post_hook_handle__"): - module.__post_hook_handle__ = module.register_forward_hook( - post_hook - ) + module.__post_hook_handle__ = module.register_forward_hook(post_hook) - self.model.apply( - partial(register_module_hooks, ignore_list=ignore_list) - ) + self.model.apply(partial(register_module_hooks, ignore_list=ignore_list)) self.started = True self.func_patched = True @@ -194,9 +186,7 @@ def _prelu_flops_compute(input: Tensor, weight: Tensor): return input.numel() -def _elu_flops_compute( - input: Tensor, alpha: float = 1.0, inplace: bool = False -): +def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False): return input.numel() @@ -259,9 +249,7 @@ def _conv_flops_compute( output_dims.append(output_dim) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(output_dims)) overall_conv_macs = conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -297,7 +285,6 @@ def _conv_trans_flops_compute( output_dims = [] for idx, input_dim in enumerate(input_dims): - output_dim = ( input_dim + 2 * paddings[idx] @@ -310,9 +297,7 @@ def _conv_trans_flops_compute( dilations = dilation if type(dilation) is tuple else (dilation, dilation) filters_per_channel = out_channels // groups - conv_per_position_macs = ( - int(_prod(kernel_dims)) * in_channels * filters_per_channel - ) + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel active_elements_count = batch_size * int(_prod(input_dims)) overall_conv_macs = conv_per_position_macs * active_elements_count overall_conv_flops = 2 * overall_conv_macs @@ -389,9 +374,7 @@ def _upsample_flops_compute(input, **kwargs): else: return int(size), 0 scale_factor = kwargs.get("scale_factor", None) - assert ( - scale_factor is not None - ), "either size or scale_factor should be defined" + assert scale_factor is not None, "either size or scale_factor should be defined" flops = input.numel() if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): flops * int(_prod(scale_factor)) @@ -593,12 +576,8 @@ def _patch_functionals(): F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) # swoosh functions in k2 - k2.swoosh_l_forward = wrapFunc( - k2.swoosh_l_forward, _k2_swoosh_flops_compute - ) - k2.swoosh_r_forward = wrapFunc( - k2.swoosh_r_forward, _k2_swoosh_flops_compute - ) + k2.swoosh_l_forward = wrapFunc(k2.swoosh_l_forward, _k2_swoosh_flops_compute) + k2.swoosh_r_forward = wrapFunc(k2.swoosh_r_forward, _k2_swoosh_flops_compute) k2.swoosh_l = wrapFunc(k2.swoosh_l, _k2_swoosh_flops_compute) k2.swoosh_r = wrapFunc(k2.swoosh_r, _k2_swoosh_flops_compute) @@ -612,9 +591,7 @@ def _patch_tensor_methods(): torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute) torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) - torch.Tensor.addmm = wrapFunc( - torch.Tensor.addmm, _tensor_addmm_flops_compute - ) + torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) torch.mul = wrapFunc(torch.mul, _mul_flops_compute) torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute) @@ -631,14 +608,10 @@ def _patch_tensor_methods(): torch.tanh = wrapFunc(torch.tanh, _tanh_flops_compute) - torch.Tensor.softmax = wrapFunc( - torch.Tensor.softmax, _softmax_flops_compute - ) + torch.Tensor.softmax = wrapFunc(torch.Tensor.softmax, _softmax_flops_compute) torch.sigmoid = wrapFunc(torch.sigmoid, _sigmoid_flops_compute) - torch.Tensor.sigmoid = wrapFunc( - torch.Tensor.sigmoid, _sigmoid_flops_compute - ) + torch.Tensor.sigmoid = wrapFunc(torch.Tensor.sigmoid, _sigmoid_flops_compute) def _reload_functionals(): @@ -732,15 +705,11 @@ def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): flops += rnn_module.hidden_size * 4 # two hadamard _product and add for C state flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size ) # final hadamard flops += ( - rnn_module.hidden_size - + rnn_module.hidden_size - + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size ) return flops diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py index d51a4b76b..28b908f82 100755 --- a/icefall/rnn_lm/check-onnx-streaming.py +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -112,7 +112,6 @@ def main(): for torch_v, onnx_v in zip( (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) ): - assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( torch_v.shape, onnx_v.shape, diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 3d206d139..0178b80bf 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -463,7 +463,6 @@ def train_one_epoch( cur_batch_idx = params.get("cur_batch_idx", 0) for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: continue cur_batch_idx = batch_idx diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index 7150297d6..231aca7f1 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -225,7 +225,6 @@ class NgramCounts: for n in range(0, self.ngram_order - 1): this_order_counts = self.counts[n] for hist, counts_for_hist in this_order_counts.items(): - n_star_star = 0 for w in counts_for_hist.word_to_count.keys(): n_star_star += len(counts_for_hist.word_to_context[w]) @@ -424,7 +423,6 @@ class NgramCounts: if __name__ == "__main__": - ngram_counts = NgramCounts(args.ngram_order) if args.text is None: diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py index 79dda3168..c78cf1821 100644 --- a/icefall/transformer_lm/model.py +++ b/icefall/transformer_lm/model.py @@ -103,7 +103,6 @@ class TransformerLM(torch.nn.Module): return nll_loss def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - bs = x.size(0) state = None diff --git a/requirements-ci.txt b/requirements-ci.txt index 2433e190b..652e2ab47 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -20,6 +20,7 @@ kaldialign==0.7.1 sentencepiece==0.1.96 tensorboard==2.8.0 typeguard==2.13.3 +black==22.3.0 multi_quantization onnx diff --git a/requirements.txt b/requirements.txt index a07f6b7c7..f0098c236 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ sentencepiece>=0.1.96 tensorboard typeguard dill +black==22.3.0