From fc7574f6d2bdb2f9c83b51148f5291aea36f9ab9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 29 Apr 2022 12:03:22 +0800
Subject: [PATCH] Add results.

---
 egs/librispeech/ASR/README.md                 |   8 +-
 egs/librispeech/ASR/RESULTS.md                | 100 ++++++
 .../pruned_transducer_stateless3/decode.py    |   4 +
 .../pruned_transducer_stateless3/export.py    | 183 +++++++++++
 .../pretrained.py                             | 289 ++++++++++++++++++
 .../ASR/pruned_transducer_stateless3/train.py |   1 -
 6 files changed, 578 insertions(+), 7 deletions(-)
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/export.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py

diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index b3e90a052..ffee8d152 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -1,9 +1,4 @@
-
-# Introduction
-
-Please refer to
-for how to run models in this recipe.
-
+# Introduction Please refer to for how to run models in this recipe.
+
 # Transducers
 
 There are various folders containing the name `transducer` in this folder.
@@ -17,6 +12,7 @@ The following table lists the differences among them.
 | `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |
+| `pruned_transducer_stateless3` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss + using GigaSpeech as extra training data |
 
 The decoder in `transducer_stateless` is modified from the paper
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 3488535a6..59c19715f 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -1,5 +1,105 @@
 ## Results
 
+### LibriSpeech BPE training results (Pruned Transducer 3)
+
+[pruned_transducer_stateless3](./pruned_transducer_stateless3)
+Same as `Pruned Transducer 2` but using the XL subset from
+[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
+
+During training, it selects either a batch from GigaSpeech with prob `giga_prob`
+or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within
+a batch come from the same dataset.
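+
+The per-batch selection amounts to a coin flip, roughly as in the sketch
+below (illustrative only, not the recipe's actual code; `giga_prob`
+corresponds to the `--giga-prob` flag in the training command):
+
+```python
+import random
+
+rng = random.Random(42)  # hypothetical seed, for the sketch only
+
+
+def next_batch(libri_iter, giga_iter, giga_prob: float):
+    # Draw the whole batch from exactly one dataset: GigaSpeech with
+    # probability `giga_prob`, LibriSpeech with probability `1 - giga_prob`.
+    if rng.random() < giga_prob:
+        return next(giga_iter)
+    return next(libri_iter)
+```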
+
+See
+
+The WERs are:
+
+| decoding method                      | test-clean | test-other | comment                                 |
+|--------------------------------------|------------|------------|-----------------------------------------|
+| greedy search (max sym per frame 1)  | 2.21       | 5.09       | --epoch 27 --avg 2 --max-duration 600   |
+| greedy search (max sym per frame 1)  | 2.25       | 5.02       | --epoch 27 --avg 12 --max-duration 600  |
+| modified beam search                 | 2.19       | 5.03       | --epoch 25 --avg 6 --max-duration 600   |
+| modified beam search                 | 2.23       | 4.94       | --epoch 27 --avg 10 --max-duration 600  |
+| beam search                          | 2.16       | 4.95       | --epoch 25 --avg 7 --max-duration 600   |
+| fast beam search                     | 2.21       | 4.96       | --epoch 27 --avg 10 --max-duration 600  |
+| fast beam search                     | 2.19       | 4.97       | --epoch 27 --avg 12 --max-duration 600  |
+
+The training commands are:
+
+```bash
+./prepare.sh
+./prepare_giga_speech.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+./pruned_transducer_stateless3/train.py \
+  --world-size 8 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --full-libri 1 \
+  --exp-dir pruned_transducer_stateless3/exp \
+  --max-duration 300 \
+  --use-fp16 1 \
+  --lr-epochs 4 \
+  --num-workers 2 \
+  --giga-prob 0.8
+```
+
+The tensorboard log can be found at
+
+(Note: The training process was stopped manually after `epoch-28.pt` was saved.)
+
+Pretrained models, training logs, decoding logs, and decoding results
+are available at
+
+The decoding commands are:
+
+```bash
+
+# greedy search
+./pruned_transducer_stateless3/decode.py \
+  --epoch 27 \
+  --avg 2 \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --max-duration 600 \
+  --decoding-method greedy_search \
+  --max-sym-per-frame 1
+
+# modified beam search
+./pruned_transducer_stateless3/decode.py \
+  --epoch 25 \
+  --avg 6 \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --max-duration 600 \
+  --decoding-method modified_beam_search \
+  --max-sym-per-frame 1
+
+# beam search
+./pruned_transducer_stateless3/decode.py \
+  --epoch 25 \
+  --avg 7 \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --max-duration 600 \
+  --decoding-method beam_search \
+  --max-sym-per-frame 1
+
+# fast beam search
+
+for epoch in 27; do
+  for avg in 10 12; do
+    ./pruned_transducer_stateless3/decode.py \
+      --epoch $epoch \
+      --avg $avg \
+      --exp-dir ./pruned_transducer_stateless3/exp \
+      --max-duration 600 \
+      --decoding-method fast_beam_search \
+      --max-states 32 \
+      --beam 8
+  done
+done
+```
+
 ### LibriSpeech BPE training results (Pruned Transducer 2)
 
 [pruned_transducer_stateless2](./pruned_transducer_stateless2)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index efc3190d1..9a6b5a117 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -511,6 +511,10 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
         params.suffix += f"-num-paths-{params.num_paths}"
         params.suffix += f"-nbest-scale-{params.nbest_scale}"
     elif "beam_search" in params.decoding_method:
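The effect of this hunk is that result-file suffixes for
`fast_beam_search_nbest_oracle` now record the n-best settings as well.
With hypothetical flag values `--beam 4 --max-contexts 4 --max-states 8
--num-paths 200 --nbest-scale 0.5`, the appended part of the suffix
would read:

```
-beam-4-max-contexts-4-max-states-8-num-paths-200-nbest-scale-0.5
```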
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
new file mode 100755
index 000000000..29acc7181
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./pruned_transducer_stateless3/export.py \
+  --exp-dir ./pruned_transducer_stateless3/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `exp_dir/pretrained.pt`.
+
+To use the generated file with `pruned_transducer_stateless3/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./pruned_transducer_stateless3/decode.py \
+        --exp-dir ./pruned_transducer_stateless3/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from train import get_params, get_transducer_model
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless3/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means trigram",
+    )
+
+    return parser
+
+
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    assert args.jit is False, "Support for torchscript will be added later"
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            # Skip epochs that cannot exist (epoch counts from 0)
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to("cpu")
+    model.eval()
+
+    if params.jit:
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        filename = params.exp_dir / "cpu_jit.pt"
+        model.save(str(filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torch.jit.script")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
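For orientation, the `--avg` logic above hands a list of `epoch-*.pt` files to
`average_checkpoints`, which is conceptually just an element-wise mean over the
saved `state_dict`s. A minimal sketch of that idea (a simplified stand-in, not
icefall's actual implementation):

```python
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
    """Element-wise average of the "model" state_dicts in the given
    checkpoints (assumed saved as {"model": state_dict, ...})."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            # Integer entries (e.g., step counters) use integer division.
            avg[k] //= n
    return avg
```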
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
new file mode 100755
index 000000000..14f72b00d
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless3/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method greedy_search \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) beam search
+./pruned_transducer_stateless3/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method beam_search \
+  --beam-size 4 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless3/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless3/exp/pretrained.pt is generated by
+./pruned_transducer_stateless3/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="Path to bpe.model.",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="Used only when --method is beam_search or modified_beam_search",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means trigram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lengths
+    )
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    msg = f"Using {params.method}"
+    if params.method == "beam_search":
+        msg += f" with beam size {params.beam_size}"
+    logging.info(msg)
+    if params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 037f99bc7..4966ea57f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -968,7 +968,6 @@ def run(rank, world_size, args):
 
     train_giga_cuts = gigaspeech.train_S_cuts()
     train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
-    train_giga_cuts = train_giga_cuts.repeat(times=None)
 
     if args.enable_musan:
         cuts_musan = load_manifest(