mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add transducer_stateless with char unit to AIShell
This commit is contained in:
parent
234307f33a
commit
18cdea4745
@ -1,5 +1,37 @@
|
||||
## Results
|
||||
|
||||
### Aishell training results (Transducer-stateless)
|
||||
#### 2021-12-29
|
||||
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
|
||||
|
||||
||test|
|
||||
|--|--|
|
||||
|CER| 5.7% |
|
||||
|
||||
You can use the following commands to reproduce our results:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8"
|
||||
./transducer_stateless/train.py \
|
||||
--bucketing-sampler True \
|
||||
--world-size 8 \
|
||||
--lang-dir data/lang_char \
|
||||
--num-epochs 40 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp_char \
|
||||
--max-duration 160 \
|
||||
--lr-factor 3
|
||||
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 39 \
|
||||
--avg 10 \
|
||||
--lang-dir data/lang_char \
|
||||
--exp-dir transducer_stateless/exp_char \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
### Aishell training results (Conformer-MMI)
|
||||
#### 2021-12-04
|
||||
(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/140>
|
||||
|
@ -538,9 +538,13 @@ def main():
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -22,8 +22,8 @@ import torch.nn as nn
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Input dim. The input shape is (N, T, idim).
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
assert idim >= 7
|
||||
super().__init__()
|
||||
@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
# On entry, x is [N, T, idim]
|
||||
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
x = self.conv(x)
|
||||
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
|
||||
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
|
||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||
return x
|
||||
|
||||
|
||||
@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
|
||||
This paper is not 100% explicit so I am guessing to some extent,
|
||||
and trying to compare with other VGG implementations.
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
||||
"""
|
||||
|
||||
@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
|
||||
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Input dim. The input shape is (N, T, idim).
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
Its shape is (N, T, idim).
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.layers(x)
|
||||
|
@ -614,8 +614,8 @@ def run(rank, world_size, args):
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
train_dl = aishell.train_dataloaders()
|
||||
valid_dl = aishell.valid_dataloaders()
|
||||
train_dl = aishell.train_dataloaders(aishell.train_cuts())
|
||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
@ -557,9 +557,13 @@ def main():
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_sets = ["test"]
|
||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -608,8 +608,9 @@ def run(rank, world_size, args):
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
train_dl = aishell.train_dataloaders()
|
||||
valid_dl = aishell.valid_dataloaders()
|
||||
train_cuts = aishell.train_cuts()
|
||||
train_dl = aishell.train_dataloaders(train_cuts)
|
||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
68
egs/aishell/ASR/local/make_syllable_lexicon.py
Executable file
68
egs/aishell/ASR/local/make_syllable_lexicon.py
Executable file
@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 (Author: Pingfeng Luo)
|
||||
"""
|
||||
make syllables lexicon and handle heteronym
|
||||
"""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from pypinyin import pinyin, lazy_pinyin, Style
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_line(
|
||||
line: str
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
line:
|
||||
A line of transcript consisting of space(s) separated word and phones
|
||||
input :
|
||||
你好 n i3 h ao3
|
||||
晴天 q ing2 t ian1
|
||||
|
||||
output :
|
||||
你好 ni3 hao3
|
||||
晴天 qing2 tian1
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
chars = line.strip().split()[0]
|
||||
pinyins = pinyin(chars, style=Style.TONE3, heteronym=True)
|
||||
word_syllables = []
|
||||
word_syllables_num = 1
|
||||
inited = False
|
||||
for char_syllables in pinyins :
|
||||
new_char_syllables_num = len(char_syllables)
|
||||
if not inited and len(char_syllables) :
|
||||
word_syllables = [char_syllables[0]]
|
||||
inited = True
|
||||
elif new_char_syllables_num == 1 :
|
||||
for i in range(word_syllables_num) :
|
||||
word_syllables[i] += " " + str(char_syllables)
|
||||
elif new_char_syllables_num > 1 :
|
||||
word_syllables = word_syllables * new_char_syllables_num
|
||||
for pre_index in range(word_syllables_num) :
|
||||
for expand_index in range(new_char_syllables_num) :
|
||||
word_syllables[pre_index * new_char_syllables_num + expand_index] += " " + char_syllables[expand_index]
|
||||
word_syllables_num *= new_char_syllables_num
|
||||
|
||||
for word_syallable in word_syllables :
|
||||
print("{} {}".format(chars.strip(), str(word_syallable).strip()))
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert Path(args.lexicon).is_file()
|
||||
|
||||
with open(args.lexicon) as f:
|
||||
for line in f:
|
||||
process_line(line=line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@ -314,8 +315,14 @@ def lexicon_to_fst(
|
||||
return fsa
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone or data/lang_syllable")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path("data/lang_phone")
|
||||
out_dir = Path(get_args().lang_dir)
|
||||
lexicon_filename = out_dir / "lexicon.txt"
|
||||
sil_token = "SIL"
|
||||
sil_prob = 0.5
|
||||
|
@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
|
||||
|
||||
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang.py
|
||||
./local/prepare_lang.py --lang-dir $lang_phone_dir
|
||||
fi
|
||||
|
||||
# Train a bigram P for MMI training
|
||||
@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
|
||||
find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
|
||||
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
|
||||
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
|
||||
cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
|
||||
|
@ -318,7 +318,7 @@ class AishellAsrDataModule:
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
is_list = isinstance(cuts, list)
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
@ -328,40 +328,27 @@ class AishellAsrDataModule:
|
||||
sampler = BucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
if is_list:
|
||||
return test_dl
|
||||
else:
|
||||
return test_dl[0]
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
|
||||
cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
|
||||
return cuts_valid
|
||||
return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> List[CutSet]:
|
||||
test_sets = ["test"]
|
||||
cuts = []
|
||||
for test_set in test_sets:
|
||||
logging.debug("About to get test cuts")
|
||||
cuts.append(
|
||||
load_manifest(
|
||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest(self.args.manifest_dir / f"cuts_test.json.gz")
|
||||
|
@ -373,7 +373,9 @@ def main():
|
||||
# if test_set == 'test-clean': continue
|
||||
#
|
||||
test_sets = ["test"]
|
||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
||||
test_dls = [test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
|
@ -553,8 +553,8 @@ def run(rank, world_size, args):
|
||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||
|
||||
aishell = AishellAsrDataModule(args)
|
||||
train_dl = aishell.train_dataloaders()
|
||||
valid_dl = aishell.valid_dataloaders()
|
||||
train_dl = aishell.train_dataloaders(aishell.train_cuts())
|
||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
@ -22,13 +22,18 @@ import torch
|
||||
from model import Transducer
|
||||
|
||||
|
||||
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
def greedy_search(
|
||||
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||
) -> List[int]:
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
# If at frame t, it decodes more than this number of symbols,
|
||||
# it will move to the next step t+1
|
||||
max_sym_per_frame = 3
|
||||
|
||||
# symbols per frame
|
||||
sym_per_frame = 0
|
||||
|
||||
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
sym_per_utt = 0
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
if sym_per_frame >= max_sym_per_frame:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
# fmt: on
|
||||
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
|
||||
sym_per_utt += 1
|
||||
sym_per_frame += 1
|
||||
|
||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
||||
else:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
@ -56,7 +56,6 @@ class Conformer(Transformer):
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
@ -69,7 +68,6 @@ class Conformer(Transformer):
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
use_feat_batchnorm=use_feat_batchnorm,
|
||||
)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
@ -107,11 +105,6 @@ class Conformer(Transformer):
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
"""
|
||||
if self.use_feat_batchnorm:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
|
@ -15,26 +15,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch 14 \
|
||||
--avg 7 \
|
||||
--exp-dir ./transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@ -42,18 +22,19 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from asr_datamodule import AishellAsrDataModule
|
||||
from beam_search import beam_search, greedy_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
@ -70,7 +51,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
default=30,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
@ -91,10 +72,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -114,6 +95,20 @@ def get_parser():
|
||||
help="Used only when --decoding-method is beam_search",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of symbols per frame",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -129,9 +124,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict):
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
lexicon: Lexicon,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
@ -206,12 +197,12 @@ def decode_one_batch(
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
lexicon:
|
||||
It contains the token symbol table and the word symbol table.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -237,7 +228,11 @@ def decode_one_batch(
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
@ -246,7 +241,7 @@ def decode_one_batch(
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
hyps.append([lexicon.token_table[i] for i in hyp])
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
@ -258,7 +253,7 @@ def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
lexicon: Lexicon,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -269,8 +264,6 @@ def decode_dataset(
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -297,7 +290,7 @@ def decode_dataset(
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
lexicon=lexicon,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
@ -332,16 +325,19 @@ def save_results(
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
for res in results:
|
||||
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
@ -353,11 +349,11 @@ def save_results(
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
@ -368,9 +364,10 @@ def save_results(
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
AishellAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
@ -381,6 +378,9 @@ def main():
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.decoding_method == "beam_search":
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
@ -391,12 +391,14 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -422,23 +424,19 @@ def main():
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
aishell = AishellAsrDataModule(args)
|
||||
test_cuts = aishell.test_cuts()
|
||||
test_dl = aishell.test_dataloaders(test_cuts)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
test_sets = ["test"]
|
||||
test_dls = [test_dl]
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
lexicon=lexicon,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
@ -20,13 +20,14 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class implements the stateless decoder from the following paper:
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network.
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
@ -104,6 +104,14 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
@ -48,7 +47,7 @@ class Joiner(nn.Module):
|
||||
# Now decoder_out is (N, 1, U, C)
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = F.relu(logit)
|
||||
logit = torch.tanh(logit)
|
||||
|
||||
output = self.output_linear(logit)
|
||||
|
||||
|
@ -110,6 +110,22 @@ def get_parser():
|
||||
help="Used only when --method is beam_search",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -279,7 +291,11 @@ def main():
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang
|
||||
# Mingshuang Luo)
|
||||
# Copyright 2021 (Pingfeng Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -120,6 +121,14 @@ def get_parser():
|
||||
help="The lr_factor for Noam optimizer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -161,15 +170,10 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
||||
input features.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||
|
||||
- weight_decay: The weight_decay for the optimizer.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
@ -191,11 +195,7 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
# parameters for Noam
|
||||
"weight_decay": 1e-6,
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -556,11 +555,10 @@ def run(rank, world_size, args):
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
oov='<unk>',
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0]
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
@ -584,7 +582,6 @@ def run(rank, world_size, args):
|
||||
model_size=params.attention_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
@ -611,8 +608,7 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||
|
||||
train_dl = aishell.train_dataloaders(train_cuts)
|
||||
|
||||
valid_dl = aishell.valid_dataloaders(aishell.dev_cuts())
|
||||
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
|
||||
dropout: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
|
||||
If True, use pre-layer norm; False to use post-layer norm.
|
||||
vgg_frontend:
|
||||
True to use vgg style frontend for subsampling.
|
||||
use_feat_batchnorm:
|
||||
True to use batchnorm for the input layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.use_feat_batchnorm = use_feat_batchnorm
|
||||
if use_feat_batchnorm:
|
||||
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
||||
|
||||
self.num_features = num_features
|
||||
self.output_dim = output_dim
|
||||
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
"""
|
||||
if self.use_feat_batchnorm:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
Loading…
x
Reference in New Issue
Block a user