#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage:
(1) greedy search
./transducer/decode.py \
        --epoch 14 \
        --avg 7 \
        --exp-dir ./transducer/exp \
        --max-duration 100 \
        --decoding-method greedy_search

(2) beam search
./transducer/decode.py \
        --epoch 14 \
        --avg 7 \
        --exp-dir ./transducer/exp \
        --max-duration 100 \
        --decoding-method beam_search \
        --beam-size 8
"""
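# Note: The default paths used below (e.g. transducer/exp and
# data/lang_bpe_500/bpe.model) are relative, so this script is expected to
# be run from the recipe directory (egs/librispeech/ASR in icefall).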
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    write_error_stats,
)

def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=26,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=12,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'.",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=5,
        help="Used only when --decoding-method is beam_search",
    )

    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            # parameters for conformer
            "feature_dim": 80,
            "encoder_out_dim": 512,
            "subsampling_factor": 4,
            "attention_dim": 512,
            "nhead": 8,
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            "vgg_frontend": False,
            "use_feat_batchnorm": True,
            # decoder params
            "decoder_embedding_dim": 1024,
            "num_decoder_layers": 4,
            "decoder_hidden_dim": 512,
            "env_info": get_env_info(),
        }
    )
    return params


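# Note: The hyper-parameters above describe the model architecture and must
# match the values used during training; otherwise loading a checkpoint will
# typically fail with shape mismatches or missing keys in load_state_dict().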
def get_encoder_model(params: AttributeDict) -> nn.Module:
    # TODO: We can add an option to switch between Conformer and Transformer
    encoder = Conformer(
        num_features=params.feature_dim,
        output_dim=params.encoder_out_dim,
        subsampling_factor=params.subsampling_factor,
        d_model=params.attention_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        vgg_frontend=params.vgg_frontend,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )
    return encoder


def get_decoder_model(params: AttributeDict) -> nn.Module:
    decoder = Decoder(
        vocab_size=params.vocab_size,
        embedding_dim=params.decoder_embedding_dim,
        blank_id=params.blank_id,
        sos_id=params.sos_id,
        num_layers=params.num_decoder_layers,
        hidden_dim=params.decoder_hidden_dim,
        output_dim=params.encoder_out_dim,
    )
    return decoder


def get_joiner_model(params: AttributeDict) -> nn.Module:
    joiner = Joiner(
        input_dim=params.encoder_out_dim,
        output_dim=params.vocab_size,
    )
    return joiner


def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
    )
    return model


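# A transducer (RNN-T) model is the composition of the three parts above:
# the encoder (transcription network) turns acoustic frames into hidden
# states, the decoder (prediction network) consumes previously emitted
# symbols, and the joiner combines one encoder frame with one decoder state
# to produce logits over the vocabulary (including the blank symbol) that
# drive greedy_search() and beam_search() below.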
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
          if greedy_search is used, it would be "greedy_search".
          If beam search with a beam size of 7 is used, it would be
          "beam_7".
        - value: It contains the decoding result. `len(value)` equals the
          batch size. `value[i]` is the decoding result for the i-th
          utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = model.device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(
        x=feature, x_lens=feature_lens
    )
    hyps = []
    batch_size = encoder_out.size(0)

    for i in range(batch_size):
        # Decode each utterance separately, keeping only its valid frames.
        # fmt: off
        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
        # fmt: on
        if params.decoding_method == "greedy_search":
            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
        elif params.decoding_method == "beam_search":
            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
        hyps.append(sp.decode(hyp).split())

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    else:
        return {f"beam_{params.beam_size}": hyps}


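# For example, with a batch of two utterances decoded by beam search with
# --beam-size 8, decode_one_batch() could return (transcripts are made up):
#
#   {"beam_8": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}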
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if a beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    # Beam search is much slower than greedy search, so log more frequently.
    if params.decoding_method == "greedy_search":
        log_interval = 100
    else:
        log_interval = 2

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for hyp_words, ref_text in zip(hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results


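# Each entry in the returned dict pairs reference words with hypothesis
# words, e.g. (hypothetical):
#
#   results["greedy_search"][0] == (["HELLO", "WORLD"], ["HELLO", "WORD"])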
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Sort decoding settings by WER, best first.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


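# The wer-summary file written above is tab-separated, with one line per
# decoding setting (usually a single one per run), sorted by WER. For
# example (numbers are illustrative, not measured results):
#
#   settings    WER
#   beam_8      7.12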
@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in ("greedy_search", "beam_search")
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    if params.decoding_method == "beam_search":
        params.suffix += f"-beam-{params.beam_size}"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.sos_id = sp.piece_to_id("<sos/eos>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        # Average the parameters of the last `--avg` checkpoints ending at
        # `--epoch`, skipping epochs with negative indices.
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to(device)
    model.eval()
    # greedy_search() and beam_search() read the device from the model.
    model.device = device

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


# Decoding processes utterances one by one, so a single PyTorch thread is
# used to avoid oversubscription.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()