Minor fixes: update docs and make the epoch number count from 1 in pruned_transducer_stateless4/decode.py

yaozengwei 2022-05-05 19:13:45 +08:00
parent 8bf2fef1e0
commit 22ecc567cb


@@ -18,16 +18,16 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
-    --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search

 (2) beam search
-./pruned_transducer_stateless2/decode.py \
-    --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
@@ -35,8 +35,8 @@ Usage:
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
-    --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
@@ -44,8 +44,8 @@ Usage:
     --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
-    --epoch 28 \
+./pruned_transducer_stateless4/decode.py \
+    --epoch 30 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 1500 \
@@ -99,9 +99,9 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=28,
+        default=30,
         help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 0.
+        Note: Epoch counts from 1.
         You can specify --avg to use more checkpoints for model averaging.""",
     )
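For reference, here is a small standalone sketch (not part of the commit) of what 1-based epoch counting means for checkpoint selection when --use-averaged-model is False. It mirrors the loop changed in the last hunk of this diff; the helper name checkpoint_filenames and the print call are illustrative only.

# Minimal sketch, assuming checkpoints are saved as epoch-1.pt, epoch-2.pt, ...
def checkpoint_filenames(exp_dir: str, epoch: int, avg: int) -> list:
    start = epoch - avg + 1
    # Skip indices below 1 now that epochs count from 1.
    return [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]

# With the new defaults --epoch 30 --avg 15 this selects
# epoch-16.pt ... epoch-30.pt (15 checkpoints) for averaging.
print(checkpoint_filenames("pruned_transducer_stateless4/exp", 30, 15))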
@@ -128,13 +128,17 @@ def get_parser():
         "--use-averaged-model",
         type=str2bool,
         default=False,
-        help="Whether to load averaged model",
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging.",
     )

     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless4/exp",
         help="The experiment dir",
     )
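The new help text says the averaged model covers the epoch range from `epoch-avg` (excluded) to `epoch`, yet only the two endpoint checkpoints are loaded. The sketch below (not from the commit) shows the arithmetic that makes this possible, assuming each epoch-*.pt stores a running average of the weights over all epochs seen so far; the helper name average_over_range and the toy tensors are illustrative, and the real code may weight by training batches rather than whole epochs.

import torch

def average_over_range(avg_start, start, avg_end, end):
    """Average of the weights over epochs (start, end], recovered from two
    running averages: (avg_end * end - avg_start * start) / (end - start)."""
    return {
        name: (avg_end[name] * end - avg_start[name] * start) / (end - start)
        for name in avg_end
    }

# Toy usage for --epoch 30 --avg 15: epochs 16..30 are averaged using only
# the running averages stored at epoch 15 and epoch 30.
avg_15 = {"w": torch.tensor([1.0])}  # stand-in for a stored running average
avg_30 = {"w": torch.tensor([2.0])}
print(average_over_range(avg_15, 15, avg_30, 30))  # {'w': tensor([3.])}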
@@ -529,19 +533,20 @@ def main():
         start = params.epoch - params.avg + 1
         filenames = []
         for i in range(start, params.epoch + 1):
-            if start >= 0:
+            if i >= 1:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
-        assert params.iter == 0
+        assert params.iter == 0 and params.avg > 0
         start = params.epoch - params.avg
+        assert start >= 1
         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
         logging.info(
-            f"averaging modes over range with {filename_start} (excluded) "
-            f"and {filename_end}"
+            f"Calculating the averaged model over epoch range from "
+            f"{start} (excluded) to {params.epoch}"
         )
         model.to(device)
         model.load_state_dict(