diff --git a/README.md b/README.md
index 707ed09d0..1b98e1571 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,18 @@ The WER for this model is:
 
 We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
+
+#### RNN-T Conformer model
+
+Using Conformer as the encoder. See [egs/librispeech/ASR/transducer](egs/librispeech/ASR/transducer).
+
+The best WER we currently have is:
+
+|     | test-clean | test-other |
+|-----|------------|------------|
+| WER | 3.16       | 7.71       |
+
+
 ### Aishell
 
 We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]
diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py
index efca20ac3..013e065be 100644
--- a/egs/librispeech/ASR/transducer/beam_search.py
+++ b/egs/librispeech/ASR/transducer/beam_search.py
@@ -18,7 +18,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
 import torch
-from transducer.model import Transducer
+from model import Transducer
 
 
 def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
diff --git a/egs/librispeech/ASR/transducer/conformer.py b/egs/librispeech/ASR/transducer/conformer.py
index 22977b835..245aaa428 100644
--- a/egs/librispeech/ASR/transducer/conformer.py
+++ b/egs/librispeech/ASR/transducer/conformer.py
@@ -22,7 +22,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch import Tensor, nn
-from transducer.transformer import Transformer
+from transformer import Transformer
 
 from icefall.utils import make_pad_mask
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 2d7fcf41d..8b36634eb 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -46,11 +46,11 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from transducer.beam_search import beam_search, greedy_search
-from transducer.conformer import Conformer
-from transducer.decoder import Decoder
-from transducer.joiner import Joiner
-from transducer.model import Transducer
+from beam_search import beam_search, greedy_search
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=26,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
 
     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=12,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py
new file mode 100755
index 000000000..27fa8974e
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/export.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+Usage:
+./transducer/export.py \
+  --exp-dir ./transducer/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 26 \
+  --avg 12
+
+It will generate a file exp_dir/pretrained.pt
+
+To use the generated file with `transducer/decode.py`, you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./transducer/decode.py \
+        --exp-dir ./transducer/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 1 \
+        --bpe-model data/lang_bpe_500/bpe.model
+"""
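+
+# Note: `average_checkpoints` combines the last `--avg` checkpoints by an
+# elementwise mean over their parameter tensors, roughly (a sketch of the
+# idea, not the exact icefall implementation):
+#
+#     avg[k] = sum(ckpt_i[k] for i in 1..n) / n   for every key k
+#
+# Averaging several consecutive checkpoints usually decodes slightly better
+# than using the last checkpoint alone.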
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=26,
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=12,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="transducer/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.script.
+        """,
+    )
+
+    return parser
+
+
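+# NOTE: The values in get_params() below describe the Conformer encoder and
+# the LSTM decoder of the trained checkpoints; they must match the ones used
+# by ./transducer/train.py, or the averaged state_dict will not load.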
+ """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "use_feat_batchnorm": True, + # decoder params + "decoder_embedding_dim": 1024, + "num_decoder_layers": 4, + "decoder_hidden_dim": 512, + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + return encoder + + +def get_decoder_model(params: AttributeDict): + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.decoder_embedding_dim, + blank_id=params.blank_id, + sos_id=params.sos_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.decoder_hidden_dim, + output_dim=params.encoder_out_dim, + ) + return decoder + + +def get_joiner_model(params: AttributeDict): + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict): + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + 
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py
index 47b38bdc9..8a4d3ca69 100644
--- a/egs/librispeech/ASR/transducer/model.py
+++ b/egs/librispeech/ASR/transducer/model.py
@@ -23,7 +23,7 @@ import torch
 import torch.nn as nn
 import torchaudio
 import torchaudio.functional
-from transducer.encoder_interface import EncoderInterface
+from encoder_interface import EncoderInterface
 
 
 from icefall.utils import add_sos
diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py
new file mode 100755
index 000000000..9dedfc16f
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/pretrained.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+./transducer/pretrained.py \
+  --checkpoint ./transducer/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method greedy_search \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+You can also use `./transducer/exp/epoch-xx.pt`.
+
+Note: ./transducer/exp/pretrained.pt is generated by
+./transducer/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import beam_search, greedy_search
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="Path to bpe.model.",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=5,
+        help="Used only when --method is beam_search",
+    )
+
+    return parser
+
+
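+# Decoding methods (--method), informally: greedy_search walks the encoder
+# output frame by frame, feeding each frame together with the current
+# decoder (prediction network) state to the joiner and emitting the argmax
+# symbol whenever it is not <blk>; beam_search keeps the `--beam-size` best
+# partial hypotheses instead of a single one.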
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=5, + help="Used only when --method is beam_search", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "use_feat_batchnorm": True, + # decoder params + "decoder_embedding_dim": 1024, + "num_decoder_layers": 4, + "decoder_hidden_dim": 512, + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict): + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + return encoder + + +def get_decoder_model(params: AttributeDict): + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.decoder_embedding_dim, + blank_id=params.blank_id, + sos_id=params.sos_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.decoder_hidden_dim, + output_dim=params.encoder_out_dim, + ) + return decoder + + +def get_joiner_model(params: AttributeDict): + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict): + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search(model=model, encoder_out=encoder_out_i) + elif params.method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + + hyps.append(sp.decode(hyp).split()) + + print(hyps) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/transducer/test_conformer.py b/egs/librispeech/ASR/transducer/test_conformer.py index 98f7df78a..5d941d98a 100755 --- a/egs/librispeech/ASR/transducer/test_conformer.py +++ b/egs/librispeech/ASR/transducer/test_conformer.py @@ -23,7 +23,7 @@ To run this file, do: """ import torch -from transducer.conformer import Conformer +from conformer import Conformer def test_conformer(): diff --git a/egs/librispeech/ASR/transducer/test_decoder.py b/egs/librispeech/ASR/transducer/test_decoder.py index bf828fa67..44c6eb6db 100755 --- a/egs/librispeech/ASR/transducer/test_decoder.py +++ b/egs/librispeech/ASR/transducer/test_decoder.py @@ -23,7 +23,7 @@ To run this file, do: """ import torch -from transducer.decoder import Decoder +from decoder import Decoder def test_decoder(): diff --git a/egs/librispeech/ASR/transducer/test_joiner.py b/egs/librispeech/ASR/transducer/test_joiner.py index b187c5ac6..23948bbf6 100755 --- 
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    with torch.no_grad():
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
+
+    num_waves = encoder_out.size(0)
+    hyps = []
+    for i in range(num_waves):
+        # fmt: off
+        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+        # fmt: on
+        if params.method == "greedy_search":
+            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+        elif params.method == "beam_search":
+            hyp = beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
+        else:
+            raise ValueError(f"Unsupported method: {params.method}")
+
+        hyps.append(sp.decode(hyp).split())
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/transducer/test_conformer.py b/egs/librispeech/ASR/transducer/test_conformer.py
index 98f7df78a..5d941d98a 100755
--- a/egs/librispeech/ASR/transducer/test_conformer.py
+++ b/egs/librispeech/ASR/transducer/test_conformer.py
@@ -23,7 +23,7 @@ To run this file, do:
 """
 
 import torch
-from transducer.conformer import Conformer
+from conformer import Conformer
 
 
 def test_conformer():
diff --git a/egs/librispeech/ASR/transducer/test_decoder.py b/egs/librispeech/ASR/transducer/test_decoder.py
index bf828fa67..44c6eb6db 100755
--- a/egs/librispeech/ASR/transducer/test_decoder.py
+++ b/egs/librispeech/ASR/transducer/test_decoder.py
@@ -23,7 +23,7 @@ To run this file, do:
 """
 
 import torch
-from transducer.decoder import Decoder
+from decoder import Decoder
 
 
 def test_decoder():
diff --git a/egs/librispeech/ASR/transducer/test_joiner.py b/egs/librispeech/ASR/transducer/test_joiner.py
index b187c5ac6..23948bbf6 100755
--- a/egs/librispeech/ASR/transducer/test_joiner.py
+++ b/egs/librispeech/ASR/transducer/test_joiner.py
@@ -24,7 +24,7 @@ To run this file, do:
 
 import torch
 
-from transducer.joiner import Joiner
+from joiner import Joiner
 
 
 def test_joiner():
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
index a9393004f..8591e2d8a 100755
--- a/egs/librispeech/ASR/transducer/test_rnn.py
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -505,7 +505,7 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
     assert_allclose(x.grad, x_clone.grad)
 
 
-def test_lstm_forget_gate_bias(device):
+def test_lstm_forget_gate_bias(device="cpu"):
     input_size = 2
     hidden_size = 3
     num_layers = 4
diff --git a/egs/librispeech/ASR/transducer/test_transducer.py b/egs/librispeech/ASR/transducer/test_transducer.py
index a65843e9b..bd4f2c188 100755
--- a/egs/librispeech/ASR/transducer/test_transducer.py
+++ b/egs/librispeech/ASR/transducer/test_transducer.py
@@ -25,10 +25,10 @@ To run this file, do:
 
 import k2
 import torch
-from transducer.conformer import Conformer
-from transducer.decoder import Decoder
-from transducer.joiner import Joiner
-from transducer.model import Transducer
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
 
 
 def test_transducer():
@@ -61,6 +61,7 @@ def test_transducer():
         sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=output_dim,
+        output_dim=output_dim,
         embedding_dropout=0.0,
         rnn_dropout=0.0,
     )
diff --git a/egs/librispeech/ASR/transducer/test_transformer.py b/egs/librispeech/ASR/transducer/test_transformer.py
index 5e35d56a6..8f4585504 100755
--- a/egs/librispeech/ASR/transducer/test_transformer.py
+++ b/egs/librispeech/ASR/transducer/test_transformer.py
@@ -23,7 +23,7 @@ To run this file, do:
 """
 
 import torch
-from transducer.transformer import Transformer
+from transformer import Transformer
 
 
 def test_transformer():
diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py
index 27842fd6c..5d0b2d33a 100755
--- a/egs/librispeech/ASR/transducer/train.py
+++ b/egs/librispeech/ASR/transducer/train.py
@@ -44,17 +44,17 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
+from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transducer.conformer import Conformer
-from transducer.decoder import Decoder
-from transducer.joiner import Joiner
-from transducer.model import Transducer
-from transducer.transformer import Noam
+from transformer import Noam
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -92,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=78,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -126,7 +126,7 @@ def get_parser():
     parser.add_argument(
         "--lr-factor",
         type=float,
-        default=5.0,
+        default=3.0,
         help="The lr_factor for Noam optimizer",
     )
 
diff --git a/egs/librispeech/ASR/transducer/transformer.py b/egs/librispeech/ASR/transducer/transformer.py
index e38e9e12c..814290264 100644
--- a/egs/librispeech/ASR/transducer/transformer.py
+++ b/egs/librispeech/ASR/transducer/transformer.py
@@ -20,8 +20,8 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
-from transducer.encoder_interface import EncoderInterface
-from transducer.subsampling import Conv2dSubsampling, VggSubsampling
+from encoder_interface import EncoderInterface
+from subsampling import Conv2dSubsampling, VggSubsampling
 
 from icefall.utils import make_pad_mask
diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py
index 941d24a56..62e9b5b12 100755
--- a/egs/librispeech/ASR/transducer_lstm/train.py
+++ b/egs/librispeech/ASR/transducer_lstm/train.py
@@ -93,7 +93,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=78,
+        default=30,
         help="Number of epochs to train.",
     )
 
@@ -127,7 +127,7 @@ def get_parser():
     parser.add_argument(
         "--lr-factor",
        type=float,
-        default=5.0,
+        default=3.0,
         help="The lr_factor for Noam optimizer",
     )
 