Minor fixes to make it ready for merge.

This commit is contained in:
Fangjun Kuang 2021-12-17 19:43:04 +08:00
parent f6a33a85c5
commit 164321c79d
16 changed files with 592 additions and 30 deletions

View File

@ -58,6 +58,18 @@ The WER for this model is:
We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing) We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)
#### RNN-T Conformer model
Using Conformer as encoder. See [egs/librispeech/ASR/transducer][egs/librispeech/ASR/transducer].
The best WER we currently have is:
| | test-clean | test-other |
|-----|------------|------------|
| WER | 3.16 | 7.71 |
### Aishell ### Aishell
We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc] We provide two models for this recipe: [conformer CTC model][Aishell_conformer_ctc]

View File

@ -18,7 +18,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
from transducer.model import Transducer from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:

View File

@ -22,7 +22,7 @@ from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from transducer.transformer import Transformer from transformer import Transformer
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask

View File

@ -46,11 +46,11 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from transducer.beam_search import beam_search, greedy_search from beam_search import beam_search, greedy_search
from transducer.conformer import Conformer from conformer import Conformer
from transducer.decoder import Decoder from decoder import Decoder
from transducer.joiner import Joiner from joiner import Joiner
from transducer.model import Transducer from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
@ -70,14 +70,14 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=77, default=26,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=55, default=12,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",

View File

@ -0,0 +1,250 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer/export.py \
--exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 26 \
--avg 12
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer/decode.py \
--exp-dir ./transducer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=26,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=12,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -23,7 +23,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torchaudio import torchaudio
import torchaudio.functional import torchaudio.functional
from transducer.encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_sos

View File

@ -0,0 +1,299 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./transducer/pretrained.py \
--checkpoint ./transducer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer/exp/epoch-xx.pt`.
Note: ./transducer/exp/pretrained.pt is generated by
./transducer/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=5,
help="Used only when --method is beam_search",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
hyps.append(sp.decode(hyp).split())
print(hyps)
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -23,7 +23,7 @@ To run this file, do:
""" """
import torch import torch
from transducer.conformer import Conformer from conformer import Conformer
def test_conformer(): def test_conformer():

View File

@ -23,7 +23,7 @@ To run this file, do:
""" """
import torch import torch
from transducer.decoder import Decoder from decoder import Decoder
def test_decoder(): def test_decoder():

View File

@ -24,7 +24,7 @@ To run this file, do:
import torch import torch
from transducer.joiner import Joiner from joiner import Joiner
def test_joiner(): def test_joiner():

View File

@ -505,7 +505,7 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
assert_allclose(x.grad, x_clone.grad) assert_allclose(x.grad, x_clone.grad)
def test_lstm_forget_gate_bias(device): def test_lstm_forget_gate_bias(device="cpu"):
input_size = 2 input_size = 2
hidden_size = 3 hidden_size = 3
num_layers = 4 num_layers = 4

View File

@ -25,10 +25,10 @@ To run this file, do:
import k2 import k2
import torch import torch
from transducer.conformer import Conformer from conformer import Conformer
from transducer.decoder import Decoder from decoder import Decoder
from transducer.joiner import Joiner from joiner import Joiner
from transducer.model import Transducer from model import Transducer
def test_transducer(): def test_transducer():
@ -61,6 +61,7 @@ def test_transducer():
sos_id=sos_id, sos_id=sos_id,
num_layers=num_layers, num_layers=num_layers,
hidden_dim=output_dim, hidden_dim=output_dim,
output_dim=output_dim,
embedding_dropout=0.0, embedding_dropout=0.0,
rnn_dropout=0.0, rnn_dropout=0.0,
) )

View File

@ -23,7 +23,7 @@ To run this file, do:
""" """
import torch import torch
from transducer.transformer import Transformer from transformer import Transformer
def test_transformer(): def test_transformer():

View File

@ -44,17 +44,17 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transducer.conformer import Conformer from transformer import Noam
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.transformer import Noam
from icefall.checkpoint import load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@ -92,7 +92,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
default=78, default=30,
help="Number of epochs to train.", help="Number of epochs to train.",
) )
@ -126,7 +126,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lr-factor", "--lr-factor",
type=float, type=float,
default=5.0, default=3.0,
help="The lr_factor for Noam optimizer", help="The lr_factor for Noam optimizer",
) )

View File

@ -20,8 +20,8 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transducer.encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from transducer.subsampling import Conv2dSubsampling, VggSubsampling from subsampling import Conv2dSubsampling, VggSubsampling
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask

View File

@ -93,7 +93,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
default=78, default=30,
help="Number of epochs to train.", help="Number of epochs to train.",
) )
@ -127,7 +127,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lr-factor", "--lr-factor",
type=float, type=float,
default=5.0, default=3.0,
help="The lr_factor for Noam optimizer", help="The lr_factor for Noam optimizer",
) )