diff --git a/egs/aishell/ASR/decode.sh b/egs/aishell/ASR/decode.sh index 8fb46182f..31fe95ecb 100644 --- a/egs/aishell/ASR/decode.sh +++ b/egs/aishell/ASR/decode.sh @@ -5,4 +5,5 @@ export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub -python3 seamlessm4t/decode.py --epoch 3 --exp-dir seamlessm4t/exp +python3 seamlessm4t/decode.py --epoch 5 --exp-dir seamlessm4t/exp +python3 seamlessm4t/decode.py --epoch 5 --avg 2 --exp-dir seamlessm4t/exp diff --git a/egs/aishell/ASR/run.sh b/egs/aishell/ASR/run.sh index b12b00dc8..1727dd7d0 100644 --- a/egs/aishell/ASR/run.sh +++ b/egs/aishell/ASR/run.sh @@ -5,4 +5,4 @@ pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github. export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub -torchrun --nproc-per-node 8 seamlessm4t/train2.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp --start-epoch 4 +torchrun --nproc-per-node 8 seamlessm4t/train2.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp --start-epoch 6 diff --git a/egs/aishell/ASR/seamlessm4t/decode.py b/egs/aishell/ASR/seamlessm4t/decode.py index 0e7779580..c4f1307d6 100755 --- a/egs/aishell/ASR/seamlessm4t/decode.py +++ b/egs/aishell/ASR/seamlessm4t/decode.py @@ -30,7 +30,7 @@ from asr_datamodule import AishellAsrDataModule #from conformer import Conformer from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model from icefall.decode import ( get_lattice, nbest_decoding, @@ -74,7 +74,7 @@ def get_parser(): parser.add_argument( "--avg", 
type=int, - default=20, + default=1, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch'. ", @@ -277,7 +277,7 @@ def save_results( enable_log = True test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: @@ -285,7 +285,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -300,7 +300,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" + errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -323,8 +323,8 @@ def main(): params = get_params() params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}") logging.info("Decoding started") logging.info(params) @@ -342,7 +342,25 @@ def main(): del model.text_encoder del model.text_encoder_frontend if params.epoch > 0: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = 
f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()])