diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 465f8ce85..dd27e1f35 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,5 +1,37 @@ ## Results +### Aishell training results (Transducer-stateless) +#### 2021-12-29 +(Pingfeng Luo) : The tensorboard log for training is available at + +||test| +|--|--| +|CER| 5.7% | + +You can use the following commands to reproduce our results: + +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8" +./transducer_stateless/train.py \ + --bucketing-sampler True \ + --world-size 8 \ + --lang-dir data/lang_char \ + --num-epochs 40 \ + --start-epoch 0 \ + --exp-dir transducer_stateless/exp_char \ + --max-duration 160 \ + --lr-factor 3 + +./transducer_stateless/decode.py \ + --epoch 39 \ + --avg 10 \ + --lang-dir data/lang_char \ + --exp-dir transducer_stateless/exp_char \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 +``` + ### Aishell training results (Conformer-MMI) #### 2021-12-04 (Pingfeng Luo): Result of diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index dc593eeb9..c38c4c65f 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -538,9 +538,13 @@ def main(): logging.info(f"Number of model parameters: {num_param}") aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) test_sets = ["test"] - for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()): + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 720ed6c22..542fb0364 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -22,8 +22,8 @@ import torch.nn as nn class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 It is based on @@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module): """ Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ assert idim >= 7 super().__init__() @@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). 
Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ - # On entry, x is [N, T, idim] - x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) return x @@ -80,8 +80,8 @@ class VggSubsampling(nn.Module): This paper is not 100% explicit so I am guessing to some extent, and trying to compare with other VGG implementations. - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 """ @@ -93,10 +93,10 @@ class VggSubsampling(nn.Module): Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ super().__init__() @@ -149,10 +149,10 @@ class VggSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ x = x.unsqueeze(1) x = self.layers(x) diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index 629d7a373..a4bc8e3bb 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -614,8 +614,8 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders() - valid_dl = aishell.valid_dataloaders() + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py index 1d0b3daad..35a7d98fc 100755 --- a/egs/aishell/ASR/conformer_mmi/decode.py +++ b/egs/aishell/ASR/conformer_mmi/decode.py @@ -557,9 +557,13 @@ def main(): logging.info(f"Number of model parameters: {num_param}") aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) test_sets = ["test"] - for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()): + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 14ddaf5fd..79c16d1cc 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -608,8 +608,9 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders() - valid_dl = aishell.valid_dataloaders() + train_cuts = aishell.train_cuts() + train_dl = 
aishell.train_dataloaders(train_cuts) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) diff --git a/egs/aishell/ASR/local/make_syllable_lexicon.py b/egs/aishell/ASR/local/make_syllable_lexicon.py new file mode 100755 index 000000000..15c0f8ac0 --- /dev/null +++ b/egs/aishell/ASR/local/make_syllable_lexicon.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +# Copyright 2021 (Author: Pingfeng Luo) +""" + Make a syllable lexicon and handle heteronyms. +""" +import argparse +from pathlib import Path +from pypinyin import pinyin, Style + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + return parser.parse_args() + + +def process_line( + line: str +) -> None: + """ + Args: + line: + A line of the lexicon, consisting of a word followed by its space-separated phones. + input: + 你好 n i3 h ao3 + 晴天 q ing2 t ian1 + + output: + 你好 ni3 hao3 + 晴天 qing2 tian1 + Returns: + Return None. + """ + chars = line.strip().split()[0] + pinyins = pinyin(chars, style=Style.TONE3, heteronym=True) + word_syllables = [] + word_syllables_num = 1 + inited = False + for char_syllables in pinyins: + new_char_syllables_num = len(char_syllables) + if not inited and new_char_syllables_num: + word_syllables = list(char_syllables) + word_syllables_num, inited = new_char_syllables_num, True + elif new_char_syllables_num == 1: + for i in range(word_syllables_num): + word_syllables[i] += " " + char_syllables[0] + elif new_char_syllables_num > 1: + word_syllables = word_syllables * new_char_syllables_num + for expand_index in range(new_char_syllables_num): + for pre_index in range(word_syllables_num): + word_syllables[expand_index * word_syllables_num + pre_index] += " " + char_syllables[expand_index] + word_syllables_num *= new_char_syllables_num + + for word_syllable in word_syllables: + print("{} {}".format(chars.strip(), word_syllable.strip())) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + + with open(args.lexicon) as f: + for line in f: + process_line(line=line) + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py index 0880019b3..495f62cb4 100755 --- a/egs/aishell/ASR/local/prepare_lang.py +++ b/egs/aishell/ASR/local/prepare_lang.py @@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following: 5. Generate L_disambig.pt, in k2 format. """ +import argparse import math from collections import defaultdict from pathlib import Path @@ -314,8 +315,14 @@ def lexicon_to_fst( return fsa +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone or data/lang_syllable") + return parser.parse_args() + + def main(): - out_dir = Path("data/lang_phone") + out_dir = Path(get_args().lang_dir) lexicon_filename = out_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 1e78d79d9..fe8a747dc 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir if [ !
-f $lang_phone_dir/L_disambig.pt ]; then - ./local/prepare_lang.py + ./local/prepare_lang.py --lang-dir $lang_phone_dir fi # Train a bigram P for MMI training @@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid - awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt + awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | + cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt fi if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py index 4df826f53..2c7455e3a 100644 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -318,7 +318,7 @@ class AishellAsrDataModule: return valid_dl def test_dataloaders(self, cuts: CutSet) -> DataLoader: - is_list = isinstance(cuts, list) + logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) if self.args.on_the_fly_feats @@ -328,40 +328,27 @@ class AishellAsrDataModule: sampler = BucketingSampler( cuts, max_duration=self.args.max_duration, shuffle=False ) - logging.debug("About to create test dataloader") test_dl = DataLoader( test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers, ) + return test_dl - if is_list: - return test_dl - else: - return test_dl[0] @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz") + cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz") return cuts_train @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz") - return cuts_valid + return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz") @lru_cache() def test_cuts(self) -> List[CutSet]: - test_sets = ["test"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts + logging.info("About to get test cuts") + return load_manifest(self.args.manifest_dir / f"cuts_test.json.gz") diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py index c41d7da17..aa98700e5 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py @@ -373,7 +373,11 @@ def main(): # if test_set == 'test-clean': continue # test_sets = ["test"] - for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()): + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index 410f07c53..a0045115d 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -553,8 +553,8 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) aishell =
AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders() - valid_dl = aishell.valid_dataloaders() + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py index 45118a8bc..9ed9b2ad1 100644 --- a/egs/aishell/ASR/transducer_stateless/beam_search.py +++ b/egs/aishell/ASR/transducer_stateless/beam_search.py @@ -22,13 +22,18 @@ import torch from model import Transducer -def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: +def greedy_search( + model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int +) -> List[int]: """ Args: model: An instance of `Transducer`. encoder_out: A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + max_sym_per_frame: + Maximum number of symbols per frame. If it is set to 0, the WER + would be 100%. Returns: Return the decoded result. """ @@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: # Maximum symbols per utterance. max_sym_per_utt = 1000 - # If at frame t, it decodes more than this number of symbols, - # it will move to the next step t+1 - max_sym_per_frame = 3 - # symbols per frame sym_per_frame = 0 @@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: sym_per_utt = 0 while t < T and sym_per_utt < max_sym_per_utt: + if sym_per_frame >= max_sym_per_frame: + sym_per_frame = 0 + t += 1 + continue + # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on @@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: sym_per_utt += 1 sym_per_frame += 1 - - if y == blank_id or sym_per_frame > max_sym_per_frame: + else: sym_per_frame = 0 t += 1 hyp = hyp[context_size:] # remove blanks diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py index 245aaa428..81d7708f9 100644 --- a/egs/aishell/ASR/transducer_stateless/conformer.py +++ b/egs/aishell/ASR/transducer_stateless/conformer.py @@ -56,7 +56,6 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -69,7 +68,6 @@ class Conformer(Transformer): dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -107,11 +105,6 @@ class Conformer(Transformer): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. 
""" - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module): groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, @@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module): # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index 82175e8db..22640131c 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -15,26 +15,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Usage: -(1) greedy search -./transducer_stateless/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) beam search -./transducer_stateless/decode.py \ - --epoch 14 \ - --avg 7 \ - --exp-dir ./transducer_stateless/exp \ - --max-duration 100 \ - --decoding-method beam_search \ - --beam-size 4 -""" - import argparse import logging @@ -42,18 +22,19 @@ from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple -import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AishellAsrDataModule from beam_search import beam_search, greedy_search from conformer import Conformer from decoder import Decoder from joiner import Joiner from model import Transducer +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -70,7 +51,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=20, + default=30, help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) @@ -91,10 +72,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_char", + help="The lang dir", ) parser.add_argument( @@ -114,6 +95,20 @@ def get_parser(): help="Used only when --decoding-method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="Maximum number of symbols per frame", + ) + return parser @@ -129,9 +124,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict): def decode_one_batch( params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, + lexicon: Lexicon, batch: dict, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -206,12 +197,12 @@ def decode_one_batch( It's the return value of :func:`get_params`. model: The neural model. - sp: - The BPE model. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + lexicon: + It contains the token symbol table and the word symbol table. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -237,7 +228,11 @@ def decode_one_batch( encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size @@ -246,7 +241,7 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + hyps.append([lexicon.token_table[i] for i in hyp]) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -258,7 +253,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, + lexicon: Lexicon, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -269,8 +264,6 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. - sp: - The BPE model. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -297,7 +290,7 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, - sp=sp, + lexicon=lexicon, batch=batch, ) @@ -332,16 +325,19 @@ def save_results( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) + # we compute CER for aishell dataset. 
+ results_char = [] + for res in results: + results_char.append((list("".join(res[0])), list("".join(res[1])))) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + f, f"{test_set_name}-{key}", results_char, enable_log=True ) test_set_wers[key] = wer @@ -353,11 +349,11 @@ / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: - print("settings\tWER", file=f) + print("settings\tCER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: s += "{}\t{}{}\n".format(key, val, note) @@ -368,9 +364,10 @@ @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) @@ -381,6 +378,9 @@ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.decoding_method == "beam_search": params.suffix += f"-beam-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -391,12 +391,14 @@ logging.info(f"Device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() + params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0] + params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) @@ -422,23 +424,19 @@ num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_sets = ["test"] + test_dls = [test_dl] - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): + for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( dl=test_dl, params=params, model=model, - sp=sp, + lexicon=lexicon, ) save_results( diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index cedbc937e..dca084477 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -20,13 +20,14 @@ import torch.nn.functional as F class Decoder(nn.Module): - """This class implements the stateless decoder from the following paper: + """This class modifies the stateless decoder from the following paper: RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 It removes the recurrent connection from the decoder, i.e., the prediction - network. + network. Different from the above paper, it adds an extra Conv1d + right after the embedding layer. TODO: Implement https://arxiv.org/pdf/2109.07513.pdf """ diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index a877b5067..641555bdb 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -104,6 +104,14 @@ def get_parser(): """, ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -119,9 +127,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder diff --git a/egs/aishell/ASR/transducer_stateless/joiner.py b/egs/aishell/ASR/transducer_stateless/joiner.py index 0422f8a6f..2ef3f1de6 100644 --- a/egs/aishell/ASR/transducer_stateless/joiner.py +++ b/egs/aishell/ASR/transducer_stateless/joiner.py @@ -16,7 +16,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F class Joiner(nn.Module): @@ -48,7 +47,7 @@ class Joiner(nn.Module): # Now decoder_out is (N, 1, U, C) logit = encoder_out + decoder_out - logit = F.relu(logit) + logit = torch.tanh(logit) output = self.output_linear(logit) diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 49efa6749..e5dba8f0e 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -110,6 +110,22 @@ def get_parser(): help="Used only when --method is beam_search", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. 
+ """, + ) + return parser @@ -126,9 +142,6 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram "env_info": get_env_info(), } ) @@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -279,7 +291,11 @@ def main(): encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] # fmt: on if params.method == "greedy_search": - hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) elif params.method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index c34fea157..7de38ed41 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -2,6 +2,7 @@ # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang # Mingshuang Luo) +# Copyright 2021 (Pingfeng Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -120,6 +121,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + return parser @@ -161,15 +170,10 @@ def get_params() -> AttributeDict: - subsampling_factor: The subsampling factor for the model. - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - attention_dim: Hidden dim for multi-head attention model. - num_decoder_layers: Number of decoder layer of transformer decoder. - - weight_decay: The weight_decay for the optimizer. - - warm_step: The warm_step for Noam optimizer. 
""" params = AttributeDict( @@ -191,11 +195,7 @@ def get_params() -> AttributeDict: "dim_feedforward": 2048, "num_encoder_layers": 12, "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for decoder - "context_size": 2, # tri-gram # parameters for Noam - "weight_decay": 1e-6, "warm_step": 80000, # For the 100h subset, use 8k "env_info": get_env_info(), } @@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict): dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, ) return encoder @@ -556,11 +555,10 @@ def run(rank, world_size, args): graph_compiler = CharCtcTrainingGraphCompiler( lexicon=lexicon, device=device, - sos_token="", - eos_token="", + oov='', ) - params.blank_id = graph_compiler.texts_to_ids("")[0] + params.blank_id = graph_compiler.texts_to_ids("")[0][0] params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) @@ -584,7 +582,6 @@ def run(rank, world_size, args): model_size=params.attention_dim, factor=params.lr_factor, warm_step=params.warm_step, - weight_decay=params.weight_decay, ) if checkpoints and "optimizer" in checkpoints: @@ -611,8 +608,7 @@ def run(rank, world_size, args): logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") train_dl = aishell.train_dataloaders(train_cuts) - - valid_dl = aishell.valid_dataloaders(aishell.dev_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py index 814290264..e851dcc32 100644 --- a/egs/aishell/ASR/transducer_stateless/transformer.py +++ b/egs/aishell/ASR/transducer_stateless/transformer.py @@ -39,7 +39,6 @@ class Transformer(EncoderInterface): dropout: float = 0.1, normalize_before: bool = True, vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, ) -> None: """ Args: @@ -65,13 +64,8 @@ class Transformer(EncoderInterface): If True, use pre-layer norm; False to use post-layer norm. vgg_frontend: True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. """ super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) self.num_features = num_features self.output_dim = output_dim @@ -131,11 +125,6 @@ class Transformer(EncoderInterface): - logit_lens, a tensor of shape (batch_size,) containing the number of frames in `logits` before padding. """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - x = self.encoder_embed(x) x = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)