diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index 17df09cc8..a828bead9 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -91,7 +91,7 @@ def get_parser():
 
 def compute_fbank_wenetspeech_splits(args):
     subset = args.training_subset
-    subset = str(subset)
+    subset = str(subset)
     num_splits = args.num_splits
     output_dir = f"data/fbank/{subset}_split_{num_splits}"
     output_dir = Path(output_dir)
diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
index a94c2d9ab..30dc5a5ec 100644
--- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py
+++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py
@@ -31,12 +31,11 @@ from lhotse import load_manifest
 
 def main():
     paths = [
-        #"./data/fbank/cuts_S.jsonl.gz",
-        #"./data/fbank/cuts_M.jsonl.gz",
-        "./data/fbank/cuts_L.jsonl.gz",
-        #"./data/fbank/cuts_DEV.jsonl.gz",
-        #"./data/fbank/cuts_TEST_NET.jsonl.gz",
-        #"./data/fbank/cuts_TEST_MEETING.jsonl.gz"
+        "./data/fbank/cuts_S.jsonl.gz",
+        "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_DEV.jsonl.gz",
+        "./data/fbank/cuts_TEST_NET.jsonl.gz",
+        "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
     ]
 
     for path in paths:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 123f562e3..d2f8d85ce 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -338,7 +338,7 @@ class WenetSpeechAsrDataModule:
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
             train_dl.sampler.load_state_dict(sampler_state_dict)
-
+
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py
index d3ac460c7..2e9bf3e0b 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -503,8 +503,7 @@ def modified_beam_search(
         for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
 
-            #topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-            topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
             topk_token_indexes = (topk_indexes % vocab_size).tolist()
 
             for k in range(len(topk_hyp_indexes)):
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index e6d170a45..9980559bf 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -900,7 +900,7 @@ def run(rank, world_size, args):
     train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
-
+
     if not params.print_diagnostics and params.start_batch == 0:
         scan_pessimistic_batches_for_oom(
             model=model,
@@ -909,7 +909,7 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             params=params,
         )
-
+
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
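Note (reviewer sketch, not part of the patch): the beam_search.py hunk swaps the torch.div(topk_indexes, vocab_size, rounding_mode="trunc") form back to integer floor division with .tolist(). For the non-negative flattened indexes returned by topk, truncation and floor division agree, and .tolist() yields plain Python ints that can safely index the per-beam hypothesis lists in the loop that follows. The snippet below is a minimal illustration with made-up values (the vocab_size of 500 and the sample indexes are hypothetical), not code from the repository.

    # Check that both roundings agree for non-negative flattened top-k indexes,
    # and that .tolist() gives plain Python ints usable as list indexes.
    import torch

    vocab_size = 500                                  # hypothetical vocabulary size
    topk_indexes = torch.tensor([3, 512, 1499])       # hypothetical flattened top-k indexes

    hyp_trunc = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
    hyp_floor = (topk_indexes // vocab_size).tolist()
    tok = (topk_indexes % vocab_size).tolist()

    assert hyp_trunc.tolist() == hyp_floor == [0, 1, 2]
    assert tok == [3, 12, 499]
    assert all(isinstance(i, int) for i in hyp_floor)  # safe for indexing Python lists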