minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py

This commit is contained in:
yaozengwei 2022-05-05 19:13:45 +08:00
parent 8bf2fef1e0
commit 22ecc567cb

View File

@ -18,16 +18,16 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
@ -35,8 +35,8 @@ Usage:
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
@ -44,8 +44,8 @@ Usage:
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
@ -99,9 +99,9 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=28,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -128,13 +128,17 @@ def get_parser():
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model",
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
@ -529,19 +533,20 @@ def main():
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
assert params.iter == 0
assert params.iter == 0 and params.avg > 0
start = params.epoch - params.avg
assert start >= 1
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"averaging modes over range with {filename_start} (excluded) "
f"and {filename_end}"
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(