add transducer_stateless with char unit to AIShell

This commit is contained in:
PingFeng Luo 2021-12-29 15:34:50 +08:00
parent 234307f33a
commit 18cdea4745
21 changed files with 275 additions and 163 deletions

View File

@@ -1,5 +1,37 @@
 ## Results
+
+### Aishell training results (Transducer-stateless)
+#### 2021-12-29
+(Pingfeng Luo): The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
+
+||test|
+|--|--|
+|CER| 5.7% |
+
+You can use the following commands to reproduce our results:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+./transducer_stateless/train.py \
+  --bucketing-sampler True \
+  --world-size 8 \
+  --lang-dir data/lang_char \
+  --num-epochs 40 \
+  --start-epoch 0 \
+  --exp-dir transducer_stateless/exp_char \
+  --max-duration 160 \
+  --lr-factor 3
+
+./transducer_stateless/decode.py \
+  --epoch 39 \
+  --avg 10 \
+  --lang-dir data/lang_char \
+  --exp-dir transducer_stateless/exp_char \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+```
+
 ### Aishell training results (Conformer-MMI)
 #### 2021-12-04
 (Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/140>

View File

@@ -538,9 +538,13 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     aishell = AishellAsrDataModule(args)
+    test_cuts = aishell.test_cuts()
+    test_dl = aishell.test_dataloaders(test_cuts)
+
     test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+    test_dls = [test_dl]
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,

View File

@@ -22,8 +22,8 @@ import torch.nn as nn
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).
-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
     It is based on
@@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
         """
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         assert idim >= 7
         super().__init__()
@@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).
         Returns:
-            Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+            Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
-        # On entry, x is [N, T, idim]
-        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        # Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x
@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
     This paper is not 100% explicit so I am guessing to some extent,
     and trying to compare with other VGG implementations.
-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
     """
@@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
         Args:
          idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >=7, idim >=7
          odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         super().__init__()
@@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
         Args:
          x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).
         Returns:
-            Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+            Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
         x = x.unsqueeze(1)
         x = self.layers(x)
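
As a quick sanity check on the shape arithmetic above: a kernel-3, stride-2 convolution with no padding maps a length L to (L - 3) // 2 + 1 == (L - 1) // 2, so applying it twice gives ((T-1)//2 - 1)//2. A minimal sketch in plain PyTorch (the channel counts here are arbitrary placeholders, not the real model dims):

```python
import torch
import torch.nn as nn

# Two kernel-3, stride-2 convolutions without padding,
# mirroring the structure of Conv2dSubsampling.
conv = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, stride=2),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, stride=2),
    nn.ReLU(),
)
for T in (7, 100, 163):
    x = torch.randn(1, 1, T, 80)  # (N, C, T, idim)
    y = conv(x)
    assert y.shape[2] == ((T - 1) // 2 - 1) // 2  # the documented T'
```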

View File

@@ -614,8 +614,8 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
 
     aishell = AishellAsrDataModule(args)
-    train_dl = aishell.train_dataloaders()
-    valid_dl = aishell.valid_dataloaders()
+    train_dl = aishell.train_dataloaders(aishell.train_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

View File

@@ -557,9 +557,13 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     aishell = AishellAsrDataModule(args)
+    test_cuts = aishell.test_cuts()
+    test_dl = aishell.test_dataloaders(test_cuts)
+
     test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+    test_dls = [test_dl]
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,

View File

@@ -608,8 +608,9 @@ def run(rank, world_size, args):
         optimizer.load_state_dict(checkpoints["optimizer"])
 
     aishell = AishellAsrDataModule(args)
-    train_dl = aishell.train_dataloaders()
-    valid_dl = aishell.valid_dataloaders()
+    train_cuts = aishell.train_cuts()
+    train_dl = aishell.train_dataloaders(train_cuts)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

View File

@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# Copyright 2021 (Author: Pingfeng Luo)
+
+"""
+Make a syllable lexicon and handle heteronyms.
+"""
+
+import argparse
+from pathlib import Path
+
+from pypinyin import pinyin, Style
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
+    return parser.parse_args()
+
+
+def process_line(line: str) -> None:
+    """
+    Args:
+      line:
+        A line of transcript consisting of space(s)-separated word and phones.
+        Input:
+          你好 n i3 h ao3
+          晴天 q ing2 t ian1
+        Output:
+          你好 ni3 hao3
+          晴天 qing2 tian1
+    Returns:
+      Return None.
+    """
+    chars = line.strip().split()[0]
+    # With heteronym=True, each character maps to a list of candidate
+    # syllables, so a word expands to all syllable combinations.
+    pinyins = pinyin(chars, style=Style.TONE3, heteronym=True)
+    word_syllables = []
+    word_syllables_num = 1
+    inited = False
+    for char_syllables in pinyins:
+        new_char_syllables_num = len(char_syllables)
+        if not inited and new_char_syllables_num:
+            # Keep every candidate of the first character as well.
+            word_syllables = list(char_syllables)
+            word_syllables_num = new_char_syllables_num
+            inited = True
+        elif new_char_syllables_num == 1:
+            for i in range(word_syllables_num):
+                word_syllables[i] += " " + char_syllables[0]
+        elif new_char_syllables_num > 1:
+            # Duplicate the partial pronunciations, then append each
+            # candidate syllable to one full copy of the partials.
+            word_syllables = word_syllables * new_char_syllables_num
+            for expand_index in range(new_char_syllables_num):
+                for pre_index in range(word_syllables_num):
+                    word_syllables[
+                        expand_index * word_syllables_num + pre_index
+                    ] += " " + char_syllables[expand_index]
+            word_syllables_num *= new_char_syllables_num
+    for word_syllable in word_syllables:
+        print("{} {}".format(chars.strip(), word_syllable.strip()))
+
+
+def main():
+    args = get_args()
+    assert Path(args.lexicon).is_file()
+    with open(args.lexicon) as f:
+        for line in f:
+            process_line(line=line)
+
+
+if __name__ == "__main__":
+    main()
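
To see why `word_syllables` grows multiplicatively, it helps to look at what `pypinyin` returns with `heteronym=True`: one candidate list per character. A small demo (the exact candidate lists depend on the installed pypinyin data files):

```python
from pypinyin import pinyin, Style

# Each character maps to a list of candidate syllables; a heteronym
# contributes more than one, so a word expands to all combinations.
print(pinyin("你好", style=Style.TONE3, heteronym=True))
# e.g. [['ni3'], ['hao3', 'hao4']]
# -> the script would print both "你好 ni3 hao3" and "你好 ni3 hao4".
```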

View File

@@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
+import argparse
 import math
 from collections import defaultdict
 from pathlib import Path
@@ -314,8 +315,14 @@ def lexicon_to_fst(
     return fsa
 
 
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="The lang dir, data/lang_phone or data/lang_syllable",
+    )
+    return parser.parse_args()
+
+
 def main():
-    out_dir = Path("data/lang_phone")
+    out_dir = Path(get_args().lang_dir)
     lexicon_filename = out_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5

View File

@@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   ./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
 
   if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
-    ./local/prepare_lang.py
+    ./local/prepare_lang.py --lang-dir $lang_phone_dir
   fi
 
   # Train a bigram P for MMI training
@@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
     aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
     find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
-    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
+      cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
   fi
 
   if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
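
The awk/cut pipeline above keeps only the transcript lines whose utterance id appears in the train uid list, then drops the id column. A toy run with made-up ids shows the join:

```bash
printf 'uid1\nuid2\n' > uids
printf 'uid1 你好 世界\nuid3 不在 训练集\n' > text
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' uids text |
  cut -d " " -f 2-
# prints: 你好 世界
```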

View File

@@ -318,7 +318,7 @@ class AishellAsrDataModule:
         return valid_dl
 
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
-        is_list = isinstance(cuts, list)
+        logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
             if self.args.on_the_fly_feats
@@ -328,40 +328,27 @@ class AishellAsrDataModule:
         sampler = BucketingSampler(
             cuts, max_duration=self.args.max_duration, shuffle=False
         )
+        logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
             sampler=sampler,
             num_workers=self.args.num_workers,
         )
-        if is_list:
-            return test_dl
-        else:
-            return test_dl[0]
+        return test_dl
 
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
+        cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
-        return cuts_valid
+        return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
 
     @lru_cache()
-    def test_cuts(self) -> List[CutSet]:
-        test_sets = ["test"]
-        cuts = []
-        for test_set in test_sets:
-            logging.debug("About to get test cuts")
-            cuts.append(
-                load_manifest(
-                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
-                )
-            )
-        return cuts
+    def test_cuts(self) -> CutSet:
+        logging.info("About to get test cuts")
+        return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")
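
After this change each `*_cuts()` method returns a single `CutSet` and the caller builds the dataloader explicitly, so the call pattern (following the names in this diff) becomes:

```python
# Sketch of the updated call pattern:
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()                # one CutSet, not a list
test_dl = aishell.test_dataloaders(test_cuts)  # cuts passed in explicitly
```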

View File

@@ -373,7 +373,9 @@ def main():
     #  if test_set == 'test-clean': continue
     #
     test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
+    test_dls = [test_dl]
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,

View File

@@ -553,8 +553,8 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])
 
     aishell = AishellAsrDataModule(args)
-    train_dl = aishell.train_dataloaders()
-    valid_dl = aishell.valid_dataloaders()
+    train_dl = aishell.train_dataloaders(aishell.train_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

View File

@@ -22,13 +22,18 @@ import torch
 from model import Transducer
 
 
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
+def greedy_search(
+    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
+) -> List[int]:
     """
     Args:
       model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+      max_sym_per_frame:
+        Maximum number of symbols per frame. If it is set to 0, the WER
+        would be 100%.
     Returns:
      Return the decoded result.
     """
@@ -55,10 +60,6 @@ def greedy_search(
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000
 
-    # If at frame t, it decodes more than this number of symbols,
-    # it will move to the next step t+1
-    max_sym_per_frame = 3
-
     # symbols per frame
     sym_per_frame = 0
@@ -66,6 +67,11 @@ def greedy_search(
     sym_per_utt = 0
 
     while t < T and sym_per_utt < max_sym_per_utt:
+        if sym_per_frame >= max_sym_per_frame:
+            sym_per_frame = 0
+            t += 1
+            continue
+
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
@@ -83,8 +89,7 @@ def greedy_search(
             sym_per_utt += 1
             sym_per_frame += 1
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
+        else:
             sym_per_frame = 0
             t += 1
 
     hyp = hyp[context_size:]  # remove blanks
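
Taken together, the loop now caps the symbols emitted per frame up front instead of checking after the fact. A schematic sketch of the control flow (`next_symbol` is a stand-in for the decoder+joiner argmax step, not a real icefall function):

```python
def greedy_search_sketch(next_symbol, T, blank_id, max_sym_per_frame=3):
    """Schematic RNN-T greedy loop with a per-frame symbol cap."""
    hyp = []
    t = sym_per_frame = sym_per_utt = 0
    max_sym_per_utt = 1000
    while t < T and sym_per_utt < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            sym_per_frame = 0  # cap reached: move to the next frame
            t += 1
            continue
        y = next_symbol(t, hyp)  # argmax over the joiner output at frame t
        if y != blank_id:
            hyp.append(y)
            sym_per_utt += 1
            sym_per_frame += 1
        else:
            sym_per_frame = 0  # blank: advance to the next frame
            t += 1
    return hyp
```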

View File

@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+        x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
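
Unlike `nn.BatchNorm1d`, `nn.LayerNorm` normalizes over the trailing dimension, while the convolution stack keeps tensors as (batch, channels, time); hence the two permutes around the norm. A standalone check:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(64)      # normalizes the last dimension (channels)
x = torch.randn(8, 64, 100)  # (batch, channels, time)
x = norm(x.permute(0, 2, 1)).permute(0, 2, 1)
assert x.shape == (8, 64, 100)
```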

View File

@@ -15,26 +15,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-(1) greedy search
-./transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
-    --exp-dir ./transducer_stateless/exp \
-    --max-duration 100 \
-    --decoding-method greedy_search
-
-(2) beam search
-./transducer_stateless/decode.py \
-    --epoch 14 \
-    --avg 7 \
-    --exp-dir ./transducer_stateless/exp \
-    --max-duration 100 \
-    --decoding-method beam_search \
-    --beam-size 4
-"""
 
 import argparse
 import logging
@@ -42,18 +22,19 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
 
-import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AishellAsrDataModule
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from model import Transducer
 
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -70,7 +51,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=30,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -91,10 +72,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="The lang dir",
     )
 
     parser.add_argument(
@@ -114,6 +95,20 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="Maximum number of symbols per frame",
+    )
+
     return parser
@@ -129,9 +124,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict):
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
     batch: dict,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -206,12 +197,12 @@ def decode_one_batch(
         It's the return value of :func:`get_params`.
       model:
         The neural model.
-      sp:
-        The BPE model.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      lexicon:
+        It contains the token symbol table and the word symbol table.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -237,7 +228,11 @@ def decode_one_batch(
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -246,7 +241,7 @@ def decode_one_batch(
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
             )
-        hyps.append(sp.decode(hyp).split())
+        hyps.append([lexicon.token_table[i] for i in hyp])
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -258,7 +253,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    lexicon: Lexicon,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -269,8 +264,6 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      sp:
-        The BPE model.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -297,7 +290,7 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
             batch=batch,
         )
@@ -332,16 +325,19 @@ def save_results(
         params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
     )
     store_transcripts(filename=recog_path, texts=results)
+    logging.info(f"The transcripts are stored in {recog_path}")
 
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
     errs_filename = (
         params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
     )
+    # We compute CER for the aishell dataset.
+    results_char = []
+    for res in results:
+        results_char.append((list("".join(res[0])), list("".join(res[1]))))
     with open(errs_filename, "w") as f:
         wer = write_error_stats(
-            f, f"{test_set_name}-{key}", results, enable_log=True
+            f, f"{test_set_name}-{key}", results_char, enable_log=True
         )
         test_set_wers[key] = wer
@@ -353,11 +349,11 @@ def save_results(
         / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
+        print("settings\tCER", file=f)
         for key, val in test_set_wers:
             print("{}\t{}".format(key, val), file=f)
 
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
     note = "\tbest for {}".format(test_set_name)
     for key, val in test_set_wers:
         s += "{}\t{}{}\n".format(key, val, note)
@@ -368,9 +364,10 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AishellAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
 
     params = get_params()
     params.update(vars(args))
@@ -381,6 +378,9 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -391,12 +391,14 @@ def main():
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
@@ -422,23 +424,19 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    librispeech = LibriSpeechAsrDataModule(args)
-
-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
+    aishell = AishellAsrDataModule(args)
+    test_cuts = aishell.test_cuts()
+    test_dl = aishell.test_dataloaders(test_cuts)
+
+    test_sets = ["test"]
+    test_dls = [test_dl]
+    for test_set, test_dl in zip(test_sets, test_dls):
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
-            sp=sp,
+            lexicon=lexicon,
         )
 
         save_results(
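
The `results_char` conversion in `save_results` is what turns the WER machinery into a CER: every reference/hypothesis word list is flattened to a character list before `write_error_stats` computes the edit distance. Schematically, with made-up strings:

```python
ref = ["今天", "天气"]
hyp = ["今天", "天汽"]
ref_chars = list("".join(ref))  # ['今', '天', '天', '气']
hyp_chars = list("".join(hyp))  # ['今', '天', '天', '汽']
# The edit distance over these character lists is 1 substitution
# out of 4 reference characters, i.e. a CER of 25%.
```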

View File

@@ -20,13 +20,14 @@ import torch.nn.functional as F
 class Decoder(nn.Module):
-    """This class implements the stateless decoder from the following paper:
+    """This class modifies the stateless decoder from the following paper:
 
         RNN-transducer with stateless prediction network
         https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
 
     It removes the recurrent connection from the decoder, i.e., the prediction
-    network.
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
 
     TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
     """

View File

@@ -104,6 +104,14 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
@@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder

View File

@@ -16,7 +16,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)
 
         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)
 
         output = self.output_linear(logit)
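
The joiner thus fuses encoder and decoder states additively and now squashes the sum with tanh instead of ReLU before the output projection. The broadcasting works as in this sketch (shapes assumed from the comments in this file):

```python
import torch

encoder_out = torch.randn(2, 5, 1, 512)  # assumed (N, T, 1, C)
decoder_out = torch.randn(2, 1, 7, 512)  # (N, 1, U, C) per the comment above
logit = torch.tanh(encoder_out + decoder_out)  # broadcasts to (N, T, U, C)
assert logit.shape == (2, 5, 7, 512)
```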

View File

@@ -110,6 +110,22 @@ def get_parser():
         help="Used only when --method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
     return parser
@@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -279,7 +291,11 @@ def main():
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size

View File

@@ -2,6 +2,7 @@
 # Copyright    2021  Xiaomi Corp.  (authors: Fangjun Kuang,
 #                                            Wei Kang
 #                                            Mingshuang Luo)
+# Copyright    2021                (Pingfeng Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -120,6 +121,14 @@ def get_parser():
         help="The lr_factor for Noam optimizer",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
@@ -161,15 +170,10 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - weight_decay: The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -191,11 +195,7 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
-            # parameters for decoder
-            "context_size": 2,  # tri-gram
             # parameters for Noam
-            "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
@@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -556,11 +555,10 @@ def run(rank, world_size, args):
     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
         device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        oov='<unk>',
     )
 
-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0]
+    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
     params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
@@ -584,7 +582,6 @@ def run(rank, world_size, args):
         model_size=params.attention_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )
 
     if checkpoints and "optimizer" in checkpoints:
@@ -611,8 +608,7 @@ def run(rank, world_size, args):
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
     train_dl = aishell.train_dataloaders(train_cuts)
-    valid_dl = aishell.valid_dataloaders(aishell.dev_cuts())
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
 
     for epoch in range(params.start_epoch, params.num_epochs):
         train_dl.sampler.set_epoch(epoch)

View File

@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)