mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
add decoding with avg model
This commit is contained in:
parent
5f399dc780
commit
cc6432443d
@ -5,4 +5,5 @@
|
|||||||
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
|
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
|
||||||
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
|
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
|
||||||
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
|
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
|
||||||
python3 seamlessm4t/decode.py --epoch 3 --exp-dir seamlessm4t/exp
|
python3 seamlessm4t/decode.py --epoch 5 --exp-dir seamlessm4t/exp
|
||||||
|
python3 seamlessm4t/decode.py --epoch 5 --avg 2 --exp-dir seamlessm4t/exp
|
||||||
|
@ -5,4 +5,4 @@ pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.
|
|||||||
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
|
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/icefall
|
||||||
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
|
export PYTHONPATH=$PYTHONPATH:/lustre/fsw/sa/yuekaiz/asr/seamless_communication/src
|
||||||
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
|
export TORCH_HOME=/lustre/fsw/sa/yuekaiz/asr/hub
|
||||||
torchrun --nproc-per-node 8 seamlessm4t/train2.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp --start-epoch 4
|
torchrun --nproc-per-node 8 seamlessm4t/train2.py --use-fp16 1 --max-duration 300 --base-lr 1e-5 --exp-dir seamlessm4t/exp --start-epoch 6
|
||||||
|
@ -30,7 +30,7 @@ from asr_datamodule import AishellAsrDataModule
|
|||||||
#from conformer import Conformer
|
#from conformer import Conformer
|
||||||
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint, average_checkpoints_with_averaged_model
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_decoding,
|
nbest_decoding,
|
||||||
@ -74,7 +74,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=1,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
@ -277,7 +277,7 @@ def save_results(
|
|||||||
enable_log = True
|
enable_log = True
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
results = sorted(results)
|
results = sorted(results)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
if enable_log:
|
if enable_log:
|
||||||
@ -285,7 +285,7 @@ def save_results(
|
|||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
# we compute CER for aishell dataset.
|
# we compute CER for aishell dataset.
|
||||||
results_char = []
|
results_char = []
|
||||||
for res in results:
|
for res in results:
|
||||||
@ -300,7 +300,7 @@ def save_results(
|
|||||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt"
|
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tCER", file=f)
|
print("settings\tCER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
@ -323,8 +323,8 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}
|
||||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -342,7 +342,25 @@ def main():
|
|||||||
del model.text_encoder
|
del model.text_encoder
|
||||||
del model.text_encoder_frontend
|
del model.text_encoder_frontend
|
||||||
if params.epoch > 0:
|
if params.epoch > 0:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
if params.avg > 1:
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user