diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 465f8ce85..dd27e1f35 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,5 +1,37 @@
## Results
+### Aishell training results (Transducer-stateless)
+#### 2021-12-29
+(Pingfeng Luo): The tensorboard log for training is available at
+
+||test|
+|--|--|
+|CER| 5.7% |
+
+You can use the following commands to reproduce our results:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+./transducer_stateless/train.py \
+ --bucketing-sampler True \
+ --world-size 8 \
+ --lang-dir data/lang_char \
+ --num-epochs 40 \
+ --start-epoch 0 \
+ --exp-dir transducer_stateless/exp_char \
+ --max-duration 160 \
+ --lr-factor 3
+
+./transducer_stateless/decode.py \
+ --epoch 39 \
+ --avg 10 \
+ --lang-dir data/lang_char \
+ --exp-dir transducer_stateless/exp_char \
+ --max-duration 100 \
+ --decoding-method beam_search \
+ --beam-size 4
+```
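+
+Greedy search decoding is also supported (only the beam search CER is reported above), for example:
+
+```bash
+./transducer_stateless/decode.py \
+  --epoch 39 \
+  --avg 10 \
+  --lang-dir data/lang_char \
+  --exp-dir transducer_stateless/exp_char \
+  --max-duration 100 \
+  --decoding-method greedy_search \
+  --max-sym-per-frame 3
+```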
+
### Aishell training results (Conformer-MMI)
#### 2021-12-04
(Pingfeng Luo): Result of
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index dc593eeb9..c38c4c65f 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -538,9 +538,13 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
+ test_cuts = aishell.test_cuts()
+ test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
- for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+ test_dls = [test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py
index 720ed6c22..542fb0364 100644
--- a/egs/aishell/ASR/conformer_ctc/subsampling.py
+++ b/egs/aishell/ASR/conformer_ctc/subsampling.py
@@ -22,8 +22,8 @@ import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
@@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
"""
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
@@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
- # On entry, x is [N, T, idim]
- x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+ # On entry, x is (N, T, idim)
+ x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
- # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+ # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
- # Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
return x
@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
- Convert an input of shape [N, T, idim] to an output
- with shape [N, T', odim], where
+ Convert an input of shape (N, T, idim) to an output
+ with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
@@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
Args:
idim:
- Input dim. The input shape is [N, T, idim].
+ Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
- Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+ Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
@@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
Args:
x:
- Its shape is [N, T, idim].
+ Its shape is (N, T, idim).
Returns:
- Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+ Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index 629d7a373..a4bc8e3bb 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -614,8 +614,8 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AishellAsrDataModule(args)
- train_dl = aishell.train_dataloaders()
- valid_dl = aishell.valid_dataloaders()
+ train_dl = aishell.train_dataloaders(aishell.train_cuts())
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 1d0b3daad..35a7d98fc 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -557,9 +557,13 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
+ test_cuts = aishell.test_cuts()
+ test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
- for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+ test_dls = [test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py
index 14ddaf5fd..79c16d1cc 100755
--- a/egs/aishell/ASR/conformer_mmi/train.py
+++ b/egs/aishell/ASR/conformer_mmi/train.py
@@ -608,8 +608,9 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AishellAsrDataModule(args)
- train_dl = aishell.train_dataloaders()
- valid_dl = aishell.valid_dataloaders()
+ train_cuts = aishell.train_cuts()
+ train_dl = aishell.train_dataloaders(train_cuts)
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
diff --git a/egs/aishell/ASR/local/make_syllable_lexicon.py b/egs/aishell/ASR/local/make_syllable_lexicon.py
new file mode 100755
index 000000000..15c0f8ac0
--- /dev/null
+++ b/egs/aishell/ASR/local/make_syllable_lexicon.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+
+# Copyright 2021 (Author: Pingfeng Luo)
+"""
+Make a syllable lexicon with pypinyin, expanding heteronyms into multiple pronunciations.
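+
+Usage (the paths are only illustrative):
+
+  ./local/make_syllable_lexicon.py --lexicon data/lang_phone/lexicon.txt \
+    > data/lang_syllable/lexicon.txt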
+"""
+import argparse
+from pathlib import Path
+from pypinyin import pinyin, Style
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+ return parser.parse_args()
+
+
+def process_line(
+ line: str
+) -> None:
+ """
+ Args:
+ line:
+        A line of the lexicon, consisting of a word followed by its
+        space-separated phones. Only the word is used; its syllables
+        are generated with pypinyin.
+        input:
+          你好 n i3 h ao3
+          晴天 q ing2 t ian1
+
+        output:
+          你好 ni3 hao3
+          晴天 qing2 tian1
+    Returns:
+      Return None. The resulting lexicon entries are printed to stdout.
+ """
+ chars = line.strip().split()[0]
+ pinyins = pinyin(chars, style=Style.TONE3, heteronym=True)
+ word_syllables = []
+ word_syllables_num = 1
+ inited = False
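+    # ``pinyins`` has one entry per character of the word; each entry is
+    # the list of candidate syllables for that character (more than one
+    # for a heteronym). The loop below expands these candidates so that
+    # each resulting pronunciation gets its own lexicon line.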
+    for char_syllables in pinyins:
+        new_char_syllables_num = len(char_syllables)
+        if not inited and len(char_syllables):
+            word_syllables = [char_syllables[0]]
+            inited = True
+        elif new_char_syllables_num == 1:
+            for i in range(word_syllables_num):
+                word_syllables[i] += " " + char_syllables[0]
+        elif new_char_syllables_num > 1:
+            word_syllables = word_syllables * new_char_syllables_num
+            for pre_index in range(word_syllables_num):
+                for expand_index in range(new_char_syllables_num):
+                    word_syllables[
+                        pre_index * new_char_syllables_num + expand_index
+                    ] += " " + char_syllables[expand_index]
+            word_syllables_num *= new_char_syllables_num
+
+    for word_syllable in word_syllables:
+        print("{} {}".format(chars.strip(), word_syllable.strip()))
+
+
+def main():
+ args = get_args()
+ assert Path(args.lexicon).is_file()
+
+    with open(args.lexicon, "r", encoding="utf-8") as f:
+ for line in f:
+ process_line(line=line)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/local/prepare_lang.py b/egs/aishell/ASR/local/prepare_lang.py
index 0880019b3..495f62cb4 100755
--- a/egs/aishell/ASR/local/prepare_lang.py
+++ b/egs/aishell/ASR/local/prepare_lang.py
@@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format.
"""
+import argparse
import math
from collections import defaultdict
from pathlib import Path
@@ -314,8 +315,14 @@ def lexicon_to_fst(
return fsa
+def get_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="The lang dir, e.g., data/lang_phone or data/lang_syllable",
+    )
+ return parser.parse_args()
+
+
def main():
- out_dir = Path("data/lang_phone")
+ out_dir = Path(get_args().lang_dir)
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 1e78d79d9..fe8a747dc 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
- ./local/prepare_lang.py
+ ./local/prepare_lang.py --lang-dir $lang_phone_dir
fi
# Train a bigram P for MMI training
@@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
- awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
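+    # Keep only the transcripts whose utterance id appears in the training
+    # set, then drop the utterance id column.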
+ awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
+ cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
fi
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 4df826f53..2c7455e3a 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -318,7 +318,7 @@ class AishellAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
- is_list = isinstance(cuts, list)
+ logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
@@ -328,40 +328,27 @@ class AishellAsrDataModule:
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
- logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
+ return test_dl
- if is_list:
- return test_dl
- else:
- return test_dl[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
- cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
+ cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
- cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
- return cuts_valid
+ return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
@lru_cache()
-    def test_cuts(self) -> List[CutSet]:
+    def test_cuts(self) -> CutSet:
- test_sets = ["test"]
- cuts = []
- for test_set in test_sets:
- logging.debug("About to get test cuts")
- cuts.append(
- load_manifest(
- self.args.feature_dir / f"cuts_{test_set}.json.gz"
- )
- )
- return cuts
+ logging.info("About to get test cuts")
+        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index c41d7da17..aa98700e5 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -373,7 +373,11 @@ def main():
# if test_set == 'test-clean': continue
#
test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+    test_cuts = aishell.test_cuts()
+    test_dl = aishell.test_dataloaders(test_cuts)
+    test_dls = [test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
index 410f07c53..a0045115d 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py
@@ -553,8 +553,8 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
aishell = AishellAsrDataModule(args)
- train_dl = aishell.train_dataloaders()
- valid_dl = aishell.valid_dataloaders()
+ train_dl = aishell.train_dataloaders(aishell.train_cuts())
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index 45118a8bc..9ed9b2ad1 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -22,13 +22,18 @@ import torch
from model import Transducer
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
+def greedy_search(
+ model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
+) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+ max_sym_per_frame:
+        Maximum number of symbols that can be emitted per frame. Note that
+        if it is set to 0, no symbol is emitted and the WER would be 100%.
Returns:
Return the decoded result.
"""
@@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# Maximum symbols per utterance.
max_sym_per_utt = 1000
- # If at frame t, it decodes more than this number of symbols,
- # it will move to the next step t+1
- max_sym_per_frame = 3
-
# symbols per frame
sym_per_frame = 0
@@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
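+        # If we have already emitted max_sym_per_frame symbols at the
+        # current frame, move on to the next frame.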
+ if sym_per_frame >= max_sym_per_frame:
+ sym_per_frame = 0
+ t += 1
+ continue
+
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
@@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt += 1
sym_per_frame += 1
-
- if y == blank_id or sym_per_frame > max_sym_per_frame:
+ else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
diff --git a/egs/aishell/ASR/transducer_stateless/conformer.py b/egs/aishell/ASR/transducer_stateless/conformer.py
index 245aaa428..81d7708f9 100644
--- a/egs/aishell/ASR/transducer_stateless/conformer.py
+++ b/egs/aishell/ASR/transducer_stateless/conformer.py
@@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
- use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
- use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
- if self.use_feat_batchnorm:
- x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
- x = self.feat_batchnorm(x)
- x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
-
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
- self.norm = nn.BatchNorm1d(channels)
+ self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
- x = self.activation(self.norm(x))
+ # x is (batch, channels, time)
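+        # nn.LayerNorm normalizes over the last dimension, so move the
+        # channel dimension to the end before applying it and move it
+        # back afterwards.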
+ x = x.permute(0, 2, 1)
+ x = self.norm(x)
+ x = x.permute(0, 2, 1)
+
+ x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 82175e8db..22640131c 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -15,26 +15,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Usage:
-(1) greedy search
-./transducer_stateless/decode.py \
- --epoch 14 \
- --avg 7 \
- --exp-dir ./transducer_stateless/exp \
- --max-duration 100 \
- --decoding-method greedy_search
-
-(2) beam search
-./transducer_stateless/decode.py \
- --epoch 14 \
- --avg 7 \
- --exp-dir ./transducer_stateless/exp \
- --max-duration 100 \
- --decoding-method beam_search \
- --beam-size 4
-"""
-
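+"""
+Usage (the --epoch/--avg values below match RESULTS.md; adjust them for
+your own checkpoints):
+
+(1) greedy search
+./transducer_stateless/decode.py \
+    --epoch 39 \
+    --avg 10 \
+    --lang-dir data/lang_char \
+    --exp-dir ./transducer_stateless/exp_char \
+    --max-duration 100 \
+    --decoding-method greedy_search
+
+(2) beam search
+./transducer_stateless/decode.py \
+    --epoch 39 \
+    --avg 10 \
+    --lang-dir data/lang_char \
+    --exp-dir ./transducer_stateless/exp_char \
+    --max-duration 100 \
+    --decoding-method beam_search \
+    --beam-size 4
+"""
+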
import argparse
import logging
@@ -42,18 +22,19 @@ from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
-import sentencepiece as spm
import torch
import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@@ -70,7 +51,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
- default=20,
+ default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@@ -91,10 +72,10 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--lang-dir",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_char",
+ help="The lang dir",
)
parser.add_argument(
@@ -114,6 +95,20 @@ def get_parser():
help="Used only when --decoding-method is beam_search",
)
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=3,
+ help="Maximum number of symbols per frame",
+ )
+
return parser
@@ -129,9 +124,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
- "use_feat_batchnorm": True,
- # parameters for decoder
- "context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
- use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict):
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
- sp: spm.SentencePieceProcessor,
+ lexicon: Lexicon,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -206,12 +197,12 @@ def decode_one_batch(
It's the return value of :func:`get_params`.
model:
The neural model.
- sp:
- The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
+ lexicon:
+ It contains the token symbol table and the word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@@ -237,7 +228,11 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
- hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -246,7 +241,7 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
- hyps.append(sp.decode(hyp).split())
+ hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@@ -258,7 +253,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
- sp: spm.SentencePieceProcessor,
+ lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@@ -269,8 +264,6 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
- sp:
- The BPE model.
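+      lexicon:
+        It contains the token symbol table and the word symbol table.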
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@@ -297,7 +290,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
- sp=sp,
+ lexicon=lexicon,
batch=batch,
)
@@ -332,16 +325,19 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
- logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
+    # We compute CER for the Aishell dataset, so convert hyps/refs to lists of characters.
+ results_char = []
+ for res in results:
+ results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results, enable_log=True
+ f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
@@ -353,11 +349,11 @@ def save_results(
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
- print("settings\tWER", file=f)
+ print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
- s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
@@ -368,9 +364,10 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
- LibriSpeechAsrDataModule.add_arguments(parser)
+ AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
@@ -381,6 +378,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@@ -391,12 +391,14 @@ def main():
logging.info(f"Device: {device}")
- sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+ params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -422,23 +424,19 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
- librispeech = LibriSpeechAsrDataModule(args)
+ aishell = AishellAsrDataModule(args)
+ test_cuts = aishell.test_cuts()
+ test_dl = aishell.test_dataloaders(test_cuts)
- test_clean_cuts = librispeech.test_clean_cuts()
- test_other_cuts = librispeech.test_other_cuts()
+ test_sets = ["test"]
+ test_dls = [test_dl]
- test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
- test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
- test_sets = ["test-clean", "test-other"]
- test_dl = [test_clean_dl, test_other_dl]
-
- for test_set, test_dl in zip(test_sets, test_dl):
+ for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
- sp=sp,
+ lexicon=lexicon,
)
save_results(
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index cedbc937e..dca084477 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -20,13 +20,14 @@ import torch.nn.functional as F
class Decoder(nn.Module):
- """This class implements the stateless decoder from the following paper:
+ """This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
- network.
+ network. Different from the above paper, it adds an extra Conv1d
+ right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index a877b5067..641555bdb 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -104,6 +104,14 @@ def get_parser():
""",
)
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
return parser
@@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
- "use_feat_batchnorm": True,
- # parameters for decoder
- "context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
- use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
diff --git a/egs/aishell/ASR/transducer_stateless/joiner.py b/egs/aishell/ASR/transducer_stateless/joiner.py
index 0422f8a6f..2ef3f1de6 100644
--- a/egs/aishell/ASR/transducer_stateless/joiner.py
+++ b/egs/aishell/ASR/transducer_stateless/joiner.py
@@ -16,7 +16,6 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
- logit = F.relu(logit)
+ logit = torch.tanh(logit)
output = self.output_linear(logit)
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 49efa6749..e5dba8f0e 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -110,6 +110,22 @@ def get_parser():
help="Used only when --method is beam_search",
)
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=3,
+ help="""Maximum number of symbols per frame. Used only when
+ --method is greedy_search.
+ """,
+ )
+
return parser
@@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
- "use_feat_batchnorm": True,
- # parameters for decoder
- "context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
- use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@@ -279,7 +291,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
- hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index c34fea157..7de38ed41 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -2,6 +2,7 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
+# Copyright 2021 (Pingfeng Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -120,6 +121,14 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; "
+ "2 means tri-gram",
+ )
+
return parser
@@ -161,15 +170,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- - use_feat_batchnorm: Whether to do batch normalization for the
- input features.
-
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- - weight_decay: The weight_decay for the optimizer.
-
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@@ -191,11 +195,7 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
- "use_feat_batchnorm": True,
- # parameters for decoder
- "context_size": 2, # tri-gram
# parameters for Noam
- "weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
- use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@@ -556,11 +555,10 @@ def run(rank, world_size, args):
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
- sos_token="",
- eos_token="",
+        oov='<unk>',
)
-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0]
+    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -584,7 +582,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
- weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:
@@ -611,8 +608,7 @@ def run(rank, world_size, args):
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = aishell.train_dataloaders(train_cuts)
-
- valid_dl = aishell.valid_dataloaders(aishell.dev_cuts())
+ valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
diff --git a/egs/aishell/ASR/transducer_stateless/transformer.py b/egs/aishell/ASR/transducer_stateless/transformer.py
index 814290264..e851dcc32 100644
--- a/egs/aishell/ASR/transducer_stateless/transformer.py
+++ b/egs/aishell/ASR/transducer_stateless/transformer.py
@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
- use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
- use_feat_batchnorm:
- True to use batchnorm for the input layer.
"""
super().__init__()
- self.use_feat_batchnorm = use_feat_batchnorm
- if use_feat_batchnorm:
- self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
- if self.use_feat_batchnorm:
- x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
- x = self.feat_batchnorm(x)
- x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
-
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)