mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
minor fix, update docs, and modify the epoch number to count from 1 in the pruned_transducer_stateless4/decode.py
This commit is contained in:
parent
8bf2fef1e0
commit
22ecc567cb
@ -18,16 +18,16 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 28 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 28 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
@ -35,8 +35,8 @@ Usage:
|
|||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 28 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
@ -44,8 +44,8 @@ Usage:
|
|||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless4/decode.py \
|
||||||
--epoch 28 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 1500 \
|
--max-duration 1500 \
|
||||||
@ -99,9 +99,9 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=30,
|
||||||
help="""It specifies the checkpoint to use for decoding.
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
Note: Epoch counts from 0.
|
Note: Epoch counts from 1.
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,13 +128,17 @@ def get_parser():
|
|||||||
"--use-averaged-model",
|
"--use-averaged-model",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="Whether to load averaged model",
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless2/exp",
|
default="pruned_transducer_stateless4/exp",
|
||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -529,19 +533,20 @@ def main():
|
|||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
for i in range(start, params.epoch + 1):
|
for i in range(start, params.epoch + 1):
|
||||||
if start >= 0:
|
if i >= 1:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
assert params.iter == 0
|
assert params.iter == 0 and params.avg > 0
|
||||||
start = params.epoch - params.avg
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1
|
||||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
logging.info(
|
logging.info(
|
||||||
f"averaging modes over range with {filename_start} (excluded) "
|
f"Calculating the averaged model over epoch range from "
|
||||||
f"and {filename_end}"
|
f"{start} (excluded) to {params.epoch}"
|
||||||
)
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(
|
model.load_state_dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user