diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py index 2c8249cb9..30c66ac39 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py @@ -273,8 +273,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -353,9 +352,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder( - x=feature, x_lens=feature_lens - ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -415,10 +412,7 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif ( - params.decoding_method == "greedy_search" - and params.max_sym_per_frame == 1 - ): + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -543,9 +537,7 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") return results @@ -578,8 +570,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir - / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -630,9 +621,7 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += ( - f"-{params.decoding_method}-beam-size-{params.beam_size}" - ) + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -660,9 +649,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -689,9 +678,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints( - params.exp_dir, iteration=-params.iter - )[: params.avg + 1] + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -750,9 +739,7 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: # word_table = None - decoding_graph = k2.trivial_graph( - params.vocab_size - 1, device=device - ) + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None # word_table = None @@ -760,89 +747,18 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - # Note: Please use "pip install webdataset==0.1.103" - # for installing the webdataset. - import glob - import os - - from lhotse import CutSet - from lhotse.dataset.webdataset import export_to_webdataset - # we need cut ids to display recognition results. args.return_cuts = True wenetspeech = WenetSpeechAsrDataModule(args) - dev = "dev" - test_net = "test_net" - test_meeting = "test_meeting" + dev_cuts = wenetspeech.valid_cuts() + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) - if not os.path.exists(f"{dev}/shared-0.tar"): - os.makedirs(dev, exist_ok=True) - dev_cuts = wenetspeech.valid_cuts() - export_to_webdataset( - dev_cuts, - output_path=f"{dev}/shared-%d.tar", - shard_size=300, - ) + test_net_cuts = wenetspeech.test_net_cuts() + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - if not os.path.exists(f"{test_net}/shared-0.tar"): - os.makedirs(test_net, exist_ok=True) - test_net_cuts = wenetspeech.test_net_cuts() - export_to_webdataset( - test_net_cuts, - output_path=f"{test_net}/shared-%d.tar", - shard_size=300, - ) - - if not os.path.exists(f"{test_meeting}/shared-0.tar"): - os.makedirs(test_meeting, exist_ok=True) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - export_to_webdataset( - test_meeting_cuts, - output_path=f"{test_meeting}/shared-%d.tar", - shard_size=300, - ) - - print("done") - - dev_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) - ] - cuts_dev_webdataset = CutSet.from_webdataset( - dev_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_net_shards = [ - str(path) - for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) - ] - cuts_test_net_webdataset = CutSet.from_webdataset( - test_net_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - test_meeting_shards = [ - str(path) - for path in sorted( - glob.glob(os.path.join(test_meeting, "shared-*.tar")) - ) - ] - cuts_test_meeting_webdataset = CutSet.from_webdataset( - test_meeting_shards, - split_by_worker=True, - split_by_node=True, - shuffle_shards=True, - ) - - dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) - test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) - test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset) + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] test_dl = [dev_dl, test_net_dl, test_meeting_dl] diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py index 3d736cdd4..84dca1028 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py @@ -83,9 +83,7 @@ from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[ - torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler -] +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -271,8 +269,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) parser.add_argument( @@ -295,8 +292,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" - "part.", + help="The scale to smooth the loss with am (output of encoder network)" "part.", ) parser.add_argument( @@ -645,11 +641,7 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = ( - model.device - if isinstance(model, DDP) - else next(model.parameters()).device - ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,9 +689,7 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = ( - (feature_lens // params.subsampling_factor).sum().item() - ) + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -862,9 +852,7 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or ( - cur_grad_scale < 8.0 and batch_idx % 400 == 0 - ): + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -882,11 +870,7 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + ( - f"grad_scale: {scaler._scale.item()}" - if params.use_fp16 - else "" - ) + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) if tb_writer is not None: @@ -897,9 +881,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -907,10 +889,7 @@ def train_one_epoch( params.batch_idx_train, ) - if ( - batch_idx % params.valid_interval == 0 - and not params.print_diagnostics - ): + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1001,8 +980,15 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) optimizer = ScaledAdam( - model.parameters(), lr=params.base_lr, clipping_scale=2.0 + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1021,7 +1007,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts)