add transducer_stateless with char unit to AIShell

This commit is contained in:
PingFeng Luo 2021-12-29 15:34:50 +08:00
parent 234307f33a
commit 18cdea4745
21 changed files with 275 additions and 163 deletions

View File

@ -1,5 +1,37 @@
## Results
### Aishell training results (Transducer-stateless)
#### 2021-12-29
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
||test|
|--|--|
|CER| 5.7% |
You can use the following commands to reproduce our results:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8"
./transducer_stateless/train.py \
--bucketing-sampler True \
--world-size 8 \
--lang-dir data/lang_char \
--num-epochs 40 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp_char \
--max-duration 160 \
--lr-factor 3
./transducer_stateless/decode.py \
--epoch 39 \
--avg 10 \
--lang-dir data/lang_char \
--exp-dir transducer_stateless/exp_char \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
```
### Aishell training results (Conformer-MMI)
#### 2021-12-04
(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/140>

View File

@ -538,9 +538,13 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -22,8 +22,8 @@ import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
assert idim >= 7
super().__init__()
@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
Args:
x:
Its shape is [N, T, idim].
Its shape is (N, T, idim).
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
return x
@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
Args:
idim:
Input dim. The input shape is [N, T, idim].
Input dim. The input shape is (N, T, idim).
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
"""
super().__init__()
@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
Args:
x:
Its shape is [N, T, idim].
Its shape is (N, T, idim).
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
"""
x = x.unsqueeze(1)
x = self.layers(x)

View File

@ -614,8 +614,8 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AishellAsrDataModule(args)
train_dl = aishell.train_dataloaders()
valid_dl = aishell.valid_dataloaders()
train_dl = aishell.train_dataloaders(aishell.train_cuts())
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -557,9 +557,13 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["test"]
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -608,8 +608,9 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
aishell = AishellAsrDataModule(args)
train_dl = aishell.train_dataloaders()
valid_dl = aishell.valid_dataloaders()
train_cuts = aishell.train_cuts()
train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -0,0 +1,68 @@
#!/usr/bin/env python3
# Copyright 2021 (Author: Pingfeng Luo)
"""
make syllables lexicon and handle heteronym
"""
import argparse
from pathlib import Path
from pypinyin import pinyin, lazy_pinyin, Style
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
return parser.parse_args()
def process_line(
line: str
) -> None:
"""
Args:
line:
A line of transcript consisting of space(s) separated word and phones
input :
你好 n i3 h ao3
晴天 q ing2 t ian1
output :
你好 ni3 hao3
晴天 qing2 tian1
Returns:
Return None.
"""
chars = line.strip().split()[0]
pinyins = pinyin(chars, style=Style.TONE3, heteronym=True)
word_syllables = []
word_syllables_num = 1
inited = False
for char_syllables in pinyins :
new_char_syllables_num = len(char_syllables)
if not inited and len(char_syllables) :
word_syllables = [char_syllables[0]]
inited = True
elif new_char_syllables_num == 1 :
for i in range(word_syllables_num) :
word_syllables[i] += " " + str(char_syllables)
elif new_char_syllables_num > 1 :
word_syllables = word_syllables * new_char_syllables_num
for pre_index in range(word_syllables_num) :
for expand_index in range(new_char_syllables_num) :
word_syllables[pre_index * new_char_syllables_num + expand_index] += " " + char_syllables[expand_index]
word_syllables_num *= new_char_syllables_num
for word_syallable in word_syllables :
print("{} {}".format(chars.strip(), str(word_syallable).strip()))
def main():
args = get_args()
assert Path(args.lexicon).is_file()
with open(args.lexicon) as f:
for line in f:
process_line(line=line)
if __name__ == "__main__":
main()

View File

@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
@ -314,8 +315,14 @@ def lexicon_to_fst(
return fsa
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone or data/lang_syllable")
return parser.parse_args()
def main():
out_dir = Path("data/lang_phone")
out_dir = Path(get_args().lang_dir)
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5

View File

@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
./local/prepare_lang.py
./local/prepare_lang.py --lang-dir $lang_phone_dir
fi
# Train a bigram P for MMI training
@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
fi
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then

View File

@ -318,7 +318,7 @@ class AishellAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
is_list = isinstance(cuts, list)
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
@ -328,40 +328,27 @@ class AishellAsrDataModule:
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
if is_list:
return test_dl
else:
return test_dl[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
return cuts_valid
return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts
logging.info("About to get test cuts")
return load_manifest(self.args.manifest_dir / f"cuts_test.json.gz")

View File

@ -373,7 +373,9 @@ def main():
# if test_set == 'test-clean': continue
#
test_sets = ["test"]
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
test_dls = [test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,

View File

@ -553,8 +553,8 @@ def run(rank, world_size, args):
scheduler.load_state_dict(checkpoints["scheduler"])
aishell = AishellAsrDataModule(args)
train_dl = aishell.train_dataloaders()
valid_dl = aishell.valid_dataloaders()
train_dl = aishell.train_dataloaders(aishell.train_cuts())
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -22,13 +22,18 @@ import torch
from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
Returns:
Return the decoded result.
"""
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# If at frame t, it decodes more than this number of symbols,
# it will move to the next step t+1
max_sym_per_frame = 3
# symbols per frame
sym_per_frame = 0
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt += 1
sym_per_frame += 1
if y == blank_id or sym_per_frame > max_sym_per_frame:
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)

View File

@ -15,26 +15,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
"""
import argparse
import logging
@ -42,18 +22,19 @@ from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import AishellAsrDataModule
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@ -70,7 +51,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=20,
default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@ -91,10 +72,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
@ -114,6 +95,20 @@ def get_parser():
help="Used only when --decoding-method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser
@ -129,9 +124,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict):
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
lexicon: Lexicon,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -206,12 +197,12 @@ def decode_one_batch(
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains the token symbol table and the word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -237,7 +228,11 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
@ -246,7 +241,7 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
hyps.append([lexicon.token_table[i] for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@ -258,7 +253,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -269,8 +264,6 @@ def decode_dataset(
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -297,7 +290,7 @@ def decode_dataset(
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
lexicon=lexicon,
batch=batch,
)
@ -332,16 +325,19 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
@ -353,11 +349,11 @@ def save_results(
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
@ -368,9 +364,10 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
@ -381,6 +378,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -391,12 +391,14 @@ def main():
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@ -422,23 +424,19 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test"]
test_dls = [test_dl]
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
lexicon=lexicon,
)
save_results(

View File

@ -20,13 +20,14 @@ import torch.nn.functional as F
class Decoder(nn.Module):
"""This class implements the stateless decoder from the following paper:
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network.
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""

View File

@ -104,6 +104,14 @@ def get_parser():
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder

View File

@ -16,7 +16,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
logit = torch.tanh(logit)
output = self.output_linear(logit)

View File

@ -110,6 +110,22 @@ def get_parser():
help="Used only when --method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -279,7 +291,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size

View File

@ -2,6 +2,7 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
# Copyright 2021 (Pingfeng Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -120,6 +121,14 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -161,15 +170,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -191,11 +195,7 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -556,11 +555,10 @@ def run(rank, world_size, args):
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
oov='<unk>',
)
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0]
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@ -584,7 +582,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:
@ -611,8 +608,7 @@ def run(rank, world_size, args):
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
train_dl = aishell.train_dataloaders(train_cuts)
valid_dl = aishell.valid_dataloaders(aishell.dev_cuts())
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)