mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00

Minor fixes for decoding.

parent 52b3ed2920
commit b0e4e5cf31
@@ -24,15 +24,24 @@ Usage:
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method greedy_search
-(2) beam search
 
+(2) beam search
 ./transducer_lstm/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_lstm/exp \
     --max-duration 100 \
     --decoding-method beam_search \
-    --beam-size 8
+    --beam-size 4
+
+(3) modified beam search
+./transducer_lstm/decode.py \
+    --epoch 14 \
+    --avg 7 \
+    --exp-dir ./transducer_lstm/exp \
+    --max-duration 100 \
+    --decoding-method modified_beam_search \
+    --beam-size 4
 """
@@ -71,14 +80,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
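Note: the --epoch/--avg pair selects which checkpoints get averaged. A minimal
sketch (not the actual icefall code) of the selection the help string
describes, assuming the epoch-N.pt naming used elsewhere in the recipe:

    def checkpoint_filenames(epoch: int, avg: int) -> list:
        # Average the `avg` consecutive checkpoints ending at `epoch`.
        start = epoch - avg + 1
        assert start >= 0, f"--avg {avg} is too large for --epoch {epoch}"
        return [f"epoch-{i}.pt" for i in range(start, epoch + 1)]

    # With the new defaults: epoch-17.pt, epoch-18.pt, ..., epoch-29.pt
    print(checkpoint_filenames(epoch=29, avg=13))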
@@ -112,8 +121,9 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
-        help="Used only when --decoding-method is beam_search",
+        default=4,
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
     )
 
     parser.add_argument(
@@ -123,7 +133,6 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; "
         "2 means tri-gram",
     )
-
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -348,12 +357,19 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
     params.res_dir = params.exp_dir / params.decoding_method
 
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
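The new condition is a substring test, so it covers both beam-search variants
with a single branch. A quick illustration of why the suffix logic works:

    # "beam_search" is a substring of "modified_beam_search", so both
    # methods take the "-beam-..." suffix; greedy_search falls through to
    # the context/max-sym-per-frame suffix in the new else branch.
    for method in ("greedy_search", "beam_search", "modified_beam_search"):
        print(method, "beam_search" in method)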
@@ -423,8 +439,5 @@ def main():
     logging.info("Done!")
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
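The two removed module-level calls had pinned decoding to a single CPU thread;
after this commit the script runs with PyTorch's default thread pools. A small
check of the effective settings (a sketch, not part of the commit):

    import torch

    print(torch.get_num_threads())          # intra-op parallelism
    print(torch.get_num_interop_threads())  # inter-op parallelism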
@@ -26,7 +26,7 @@ class Joiner(nn.Module):
         self.output_linear = nn.Linear(input_dim, output_dim)
 
     def forward(
-        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
     ) -> torch.Tensor:
         """
         Args:
@@ -51,5 +51,7 @@ class Joiner(nn.Module):
         logit = F.relu(logit)
 
         output = self.output_linear(logit)
+        if not self.training:
+            output = output.squeeze(2).squeeze(1)
 
         return output
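The *unused parameter lets callers pass extra positional arguments that this
Joiner ignores, and the new branch drops singleton dimensions only at
inference time. A sketch of the shape handling, assuming (from the diff alone)
that the joiner sees 4-D (N, T, U, C) inputs in training and singleton T/U
dimensions during search:

    import torch

    x = torch.randn(8, 1, 1, 500)    # one frame, one decoder step
    y = x.squeeze(2).squeeze(1)      # squeeze(2): (8, 1, 500); squeeze(1): (8, 500)
    assert y.shape == (8, 500)

    z = torch.randn(8, 100, 5, 500)  # training-shaped input is unaffected
    assert z.squeeze(2).squeeze(1).shape == (8, 100, 5, 500)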
@@ -634,13 +634,23 @@ def run(rank, world_size, args):
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+    try:
+        num_left = len(train_cuts)
+        num_removed = num_in_total - num_left
+        removed_percent = num_removed / num_in_total * 100
+
+        logging.info(
+            f"Before removing short and long utterances: {num_in_total}"
+        )
+        logging.info(f"After removing short and long utterances: {num_left}")
+        logging.info(
+            f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
+        )
+    except TypeError as e:
+        # You can ignore this error as previous versions of Lhotse work fine
+        # for the above code. In recent versions of Lhotse, it uses
+        # lazy filter, producing cutsets that don't have the __len__ method
+        logging.info(str(e))
 
     train_dl = librispeech.train_dataloaders(train_cuts)
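The try/except exists because a lazily filtered CutSet defines no __len__, so
len() raises TypeError. A minimal stand-in (not Lhotse itself) that reproduces
the failure mode the except branch tolerates:

    class LazyCuts:
        """Toy lazy container: filter() returns a generator-backed view."""
        def __init__(self, items):
            self._items = items
        def filter(self, pred):
            return LazyCuts(c for c in self._items if pred(c))
        def __iter__(self):
            return iter(self._items)

    cuts = LazyCuts(range(10)).filter(lambda c: c % 2 == 0)
    try:
        print(len(cuts))
    except TypeError as e:
        print("lazy cutset:", e)  # generators define no __len__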