mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add transducer_stateless with char unit to AIShell
This commit is contained in:
parent
234307f33a
commit
18cdea4745
@ -1,5 +1,37 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### Aishell training results (Transducer-stateless)
|
||||||
|
#### 2021-12-29
|
||||||
|
(Pingfeng Luo) : The tensorboard log for training is available at <https://tensorboard.dev/experiment/sPEDmAQ3QcWuDAWGiKprVg/>
|
||||||
|
|
||||||
|
||test|
|
||||||
|
|--|--|
|
||||||
|
|CER| 5.7% |
|
||||||
|
|
||||||
|
You can use the following commands to reproduce our results:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8"
|
||||||
|
./transducer_stateless/train.py \
|
||||||
|
--bucketing-sampler True \
|
||||||
|
--world-size 8 \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--num-epochs 40 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--exp-dir transducer_stateless/exp_char \
|
||||||
|
--max-duration 160 \
|
||||||
|
--lr-factor 3
|
||||||
|
|
||||||
|
./transducer_stateless/decode.py \
|
||||||
|
--epoch 39 \
|
||||||
|
--avg 10 \
|
||||||
|
--lang-dir data/lang_char \
|
||||||
|
--exp-dir transducer_stateless/exp_char \
|
||||||
|
--max-duration 100 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--beam-size 4
|
||||||
|
```
|
||||||
|
|
||||||
### Aishell training results (Conformer-MMI)
|
### Aishell training results (Conformer-MMI)
|
||||||
#### 2021-12-04
|
#### 2021-12-04
|
||||||
(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/140>
|
(Pingfeng Luo): Result of <https://github.com/k2-fsa/icefall/pull/140>
|
||||||
|
@ -538,9 +538,13 @@ def main():
|
|||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
test_dls = [test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -22,8 +22,8 @@ import torch.nn as nn
|
|||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(nn.Module):
|
||||||
"""Convolutional 2D subsampling (to 1/4 length).
|
"""Convolutional 2D subsampling (to 1/4 length).
|
||||||
|
|
||||||
Convert an input of shape [N, T, idim] to an output
|
Convert an input of shape (N, T, idim) to an output
|
||||||
with shape [N, T', odim], where
|
with shape (N, T', odim), where
|
||||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||||
|
|
||||||
It is based on
|
It is based on
|
||||||
@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
idim:
|
idim:
|
||||||
Input dim. The input shape is [N, T, idim].
|
Input dim. The input shape is (N, T, idim).
|
||||||
Caution: It requires: T >=7, idim >=7
|
Caution: It requires: T >=7, idim >=7
|
||||||
odim:
|
odim:
|
||||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||||
"""
|
"""
|
||||||
assert idim >= 7
|
assert idim >= 7
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
Its shape is [N, T, idim].
|
Its shape is (N, T, idim).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||||
"""
|
"""
|
||||||
# On entry, x is [N, T, idim]
|
# On entry, x is (N, T, idim)
|
||||||
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
|
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
|
# Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
|
||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size()
|
||||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
# Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
|
|||||||
This paper is not 100% explicit so I am guessing to some extent,
|
This paper is not 100% explicit so I am guessing to some extent,
|
||||||
and trying to compare with other VGG implementations.
|
and trying to compare with other VGG implementations.
|
||||||
|
|
||||||
Convert an input of shape [N, T, idim] to an output
|
Convert an input of shape (N, T, idim) to an output
|
||||||
with shape [N, T', odim], where
|
with shape (N, T', odim), where
|
||||||
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
idim:
|
idim:
|
||||||
Input dim. The input shape is [N, T, idim].
|
Input dim. The input shape is (N, T, idim).
|
||||||
Caution: It requires: T >=7, idim >=7
|
Caution: It requires: T >=7, idim >=7
|
||||||
odim:
|
odim:
|
||||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
Its shape is [N, T, idim].
|
Its shape is (N, T, idim).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
|
||||||
"""
|
"""
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
|
@ -614,8 +614,8 @@ def run(rank, world_size, args):
|
|||||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
train_dl = aishell.train_dataloaders()
|
train_dl = aishell.train_dataloaders(aishell.train_cuts())
|
||||||
valid_dl = aishell.valid_dataloaders()
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
@ -557,9 +557,13 @@ def main():
|
|||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
test_dls = [test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -608,8 +608,9 @@ def run(rank, world_size, args):
|
|||||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
train_dl = aishell.train_dataloaders()
|
train_cuts = aishell.train_cuts()
|
||||||
valid_dl = aishell.valid_dataloaders()
|
train_dl = aishell.train_dataloaders(train_cuts)
|
||||||
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
68
egs/aishell/ASR/local/make_syllable_lexicon.py
Executable file
68
egs/aishell/ASR/local/make_syllable_lexicon.py
Executable file
@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright 2021 (Author: Pingfeng Luo)
|
||||||
|
"""
|
||||||
|
make syllables lexicon and handle heteronym
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from pypinyin import pinyin, lazy_pinyin, Style
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def process_line(
|
||||||
|
line: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
line:
|
||||||
|
A line of transcript consisting of space(s) separated word and phones
|
||||||
|
input :
|
||||||
|
你好 n i3 h ao3
|
||||||
|
晴天 q ing2 t ian1
|
||||||
|
|
||||||
|
output :
|
||||||
|
你好 ni3 hao3
|
||||||
|
晴天 qing2 tian1
|
||||||
|
Returns:
|
||||||
|
Return None.
|
||||||
|
"""
|
||||||
|
chars = line.strip().split()[0]
|
||||||
|
pinyins = pinyin(chars, style=Style.TONE3, heteronym=True)
|
||||||
|
word_syllables = []
|
||||||
|
word_syllables_num = 1
|
||||||
|
inited = False
|
||||||
|
for char_syllables in pinyins :
|
||||||
|
new_char_syllables_num = len(char_syllables)
|
||||||
|
if not inited and len(char_syllables) :
|
||||||
|
word_syllables = [char_syllables[0]]
|
||||||
|
inited = True
|
||||||
|
elif new_char_syllables_num == 1 :
|
||||||
|
for i in range(word_syllables_num) :
|
||||||
|
word_syllables[i] += " " + str(char_syllables)
|
||||||
|
elif new_char_syllables_num > 1 :
|
||||||
|
word_syllables = word_syllables * new_char_syllables_num
|
||||||
|
for pre_index in range(word_syllables_num) :
|
||||||
|
for expand_index in range(new_char_syllables_num) :
|
||||||
|
word_syllables[pre_index * new_char_syllables_num + expand_index] += " " + char_syllables[expand_index]
|
||||||
|
word_syllables_num *= new_char_syllables_num
|
||||||
|
|
||||||
|
for word_syallable in word_syllables :
|
||||||
|
print("{} {}".format(chars.strip(), str(word_syallable).strip()))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert Path(args.lexicon).is_file()
|
||||||
|
|
||||||
|
with open(args.lexicon) as f:
|
||||||
|
for line in f:
|
||||||
|
process_line(line=line)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
|||||||
|
|
||||||
5. Generate L_disambig.pt, in k2 format.
|
5. Generate L_disambig.pt, in k2 format.
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -314,8 +315,14 @@ def lexicon_to_fst(
|
|||||||
return fsa
|
return fsa
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone or data/lang_syllable")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
out_dir = Path("data/lang_phone")
|
out_dir = Path(get_args().lang_dir)
|
||||||
lexicon_filename = out_dir / "lexicon.txt"
|
lexicon_filename = out_dir / "lexicon.txt"
|
||||||
sil_token = "SIL"
|
sil_token = "SIL"
|
||||||
sil_prob = 0.5
|
sil_prob = 0.5
|
||||||
|
@ -124,7 +124,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
|
./local/generate_unique_lexicon.py --lang-dir $lang_phone_dir
|
||||||
|
|
||||||
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
|
if [ ! -f $lang_phone_dir/L_disambig.pt ]; then
|
||||||
./local/prepare_lang.py
|
./local/prepare_lang.py --lang-dir $lang_phone_dir
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Train a bigram P for MMI training
|
# Train a bigram P for MMI training
|
||||||
@ -133,7 +133,8 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
|||||||
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
|
aishell_train_uid=$dl_dir/aishell/data_aishell/transcript/aishell_train_uid
|
||||||
find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
|
find data/aishell/data_aishell/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_train_uid
|
||||||
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text | cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
|
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_train_uid $aishell_text |
|
||||||
|
cut -d " " -f 2- > $lang_phone_dir/transcript_words.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
|
if [ ! -f $lang_phone_dir/transcript_tokens.txt ]; then
|
||||||
|
@ -318,7 +318,7 @@ class AishellAsrDataModule:
|
|||||||
return valid_dl
|
return valid_dl
|
||||||
|
|
||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
is_list = isinstance(cuts, list)
|
logging.debug("About to create test dataset")
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
@ -328,40 +328,27 @@ class AishellAsrDataModule:
|
|||||||
sampler = BucketingSampler(
|
sampler = BucketingSampler(
|
||||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||||
)
|
)
|
||||||
logging.debug("About to create test dataloader")
|
|
||||||
test_dl = DataLoader(
|
test_dl = DataLoader(
|
||||||
test,
|
test,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
if is_list:
|
|
||||||
return test_dl
|
|
||||||
else:
|
|
||||||
return test_dl[0]
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train cuts")
|
||||||
cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
|
cuts_train = load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
|
||||||
return cuts_train
|
return cuts_train
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
|
return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
|
||||||
return cuts_valid
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_cuts(self) -> List[CutSet]:
|
def test_cuts(self) -> List[CutSet]:
|
||||||
test_sets = ["test"]
|
logging.info("About to get test cuts")
|
||||||
cuts = []
|
return load_manifest(self.args.manifest_dir / f"cuts_test.json.gz")
|
||||||
for test_set in test_sets:
|
|
||||||
logging.debug("About to get test cuts")
|
|
||||||
cuts.append(
|
|
||||||
load_manifest(
|
|
||||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return cuts
|
|
||||||
|
@ -373,7 +373,9 @@ def main():
|
|||||||
# if test_set == 'test-clean': continue
|
# if test_set == 'test-clean': continue
|
||||||
#
|
#
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()):
|
test_dls = [test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
|
@ -553,8 +553,8 @@ def run(rank, world_size, args):
|
|||||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||||
|
|
||||||
aishell = AishellAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
train_dl = aishell.train_dataloaders()
|
train_dl = aishell.train_dataloaders(aishell.train_cuts())
|
||||||
valid_dl = aishell.valid_dataloaders()
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
@ -22,13 +22,18 @@ import torch
|
|||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
def greedy_search(
|
||||||
|
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
|
||||||
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
encoder_out:
|
encoder_out:
|
||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
|
max_sym_per_frame:
|
||||||
|
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||||
|
would be 100%.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
# Maximum symbols per utterance.
|
# Maximum symbols per utterance.
|
||||||
max_sym_per_utt = 1000
|
max_sym_per_utt = 1000
|
||||||
|
|
||||||
# If at frame t, it decodes more than this number of symbols,
|
|
||||||
# it will move to the next step t+1
|
|
||||||
max_sym_per_frame = 3
|
|
||||||
|
|
||||||
# symbols per frame
|
# symbols per frame
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
|
|
||||||
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
sym_per_utt = 0
|
sym_per_utt = 0
|
||||||
|
|
||||||
while t < T and sym_per_utt < max_sym_per_utt:
|
while t < T and sym_per_utt < max_sym_per_utt:
|
||||||
|
if sym_per_frame >= max_sym_per_frame:
|
||||||
|
sym_per_frame = 0
|
||||||
|
t += 1
|
||||||
|
continue
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
|
|
||||||
sym_per_utt += 1
|
sym_per_utt += 1
|
||||||
sym_per_frame += 1
|
sym_per_frame += 1
|
||||||
|
else:
|
||||||
if y == blank_id or sym_per_frame > max_sym_per_frame:
|
|
||||||
sym_per_frame = 0
|
sym_per_frame = 0
|
||||||
t += 1
|
t += 1
|
||||||
hyp = hyp[context_size:] # remove blanks
|
hyp = hyp[context_size:] # remove blanks
|
||||||
|
@ -56,7 +56,6 @@ class Conformer(Transformer):
|
|||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
vgg_frontend: bool = False,
|
vgg_frontend: bool = False,
|
||||||
use_feat_batchnorm: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__(
|
super(Conformer, self).__init__(
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
@ -69,7 +68,6 @@ class Conformer(Transformer):
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
normalize_before=normalize_before,
|
normalize_before=normalize_before,
|
||||||
vgg_frontend=vgg_frontend,
|
vgg_frontend=vgg_frontend,
|
||||||
use_feat_batchnorm=use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
@ -107,11 +105,6 @@ class Conformer(Transformer):
|
|||||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
of frames in `logits` before padding.
|
of frames in `logits` before padding.
|
||||||
"""
|
"""
|
||||||
if self.use_feat_batchnorm:
|
|
||||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
|
||||||
x = self.feat_batchnorm(x)
|
|
||||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
|
||||||
|
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
groups=channels,
|
groups=channels,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
self.norm = nn.BatchNorm1d(channels)
|
self.norm = nn.LayerNorm(channels)
|
||||||
self.pointwise_conv2 = nn.Conv1d(
|
self.pointwise_conv2 = nn.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
x = self.activation(self.norm(x))
|
# x is (batch, channels, time)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
x = self.activation(x)
|
||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
|
@ -15,26 +15,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
|
||||||
Usage:
|
|
||||||
(1) greedy search
|
|
||||||
./transducer_stateless/decode.py \
|
|
||||||
--epoch 14 \
|
|
||||||
--avg 7 \
|
|
||||||
--exp-dir ./transducer_stateless/exp \
|
|
||||||
--max-duration 100 \
|
|
||||||
--decoding-method greedy_search
|
|
||||||
|
|
||||||
(2) beam search
|
|
||||||
./transducer_stateless/decode.py \
|
|
||||||
--epoch 14 \
|
|
||||||
--avg 7 \
|
|
||||||
--exp-dir ./transducer_stateless/exp \
|
|
||||||
--max-duration 100 \
|
|
||||||
--decoding-method beam_search \
|
|
||||||
--beam-size 4
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@ -42,18 +22,19 @@ from collections import defaultdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AishellAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
@ -70,7 +51,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="It specifies the checkpoint to use for decoding."
|
help="It specifies the checkpoint to use for decoding."
|
||||||
"Note: Epoch counts from 0.",
|
"Note: Epoch counts from 0.",
|
||||||
)
|
)
|
||||||
@ -91,10 +72,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--lang-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_char",
|
||||||
help="Path to the BPE model",
|
help="The lang dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -114,6 +95,20 @@ def get_parser():
|
|||||||
help="Used only when --decoding-method is beam_search",
|
help="Used only when --decoding-method is beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Maximum number of symbols per frame",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -129,9 +124,6 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
|
||||||
# parameters for decoder
|
|
||||||
"context_size": 2, # tri-gram
|
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -149,7 +141,6 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -188,7 +179,7 @@ def get_transducer_model(params: AttributeDict):
|
|||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
lexicon: Lexicon,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
@ -206,12 +197,12 @@ def decode_one_batch(
|
|||||||
It's the return value of :func:`get_params`.
|
It's the return value of :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
|
||||||
The BPE model.
|
|
||||||
batch:
|
batch:
|
||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
lexicon:
|
||||||
|
It contains the token symbol table and the word symbol table.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -237,7 +228,11 @@ def decode_one_batch(
|
|||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
@ -246,7 +241,7 @@ def decode_one_batch(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append([lexicon.token_table[i] for i in hyp])
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
@ -258,7 +253,7 @@ def decode_dataset(
|
|||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
lexicon: Lexicon,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -269,8 +264,6 @@ def decode_dataset(
|
|||||||
It is returned by :func:`get_params`.
|
It is returned by :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
|
||||||
The BPE model.
|
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -297,7 +290,7 @@ def decode_dataset(
|
|||||||
hyps_dict = decode_one_batch(
|
hyps_dict = decode_one_batch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
lexicon=lexicon,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -332,16 +325,19 @@ def save_results(
|
|||||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
store_transcripts(filename=recog_path, texts=results)
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
logging.info(f"The transcripts are stored in {recog_path}")
|
|
||||||
|
|
||||||
# The following prints out WERs, per-word error statistics and aligned
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
# ref/hyp pairs.
|
# ref/hyp pairs.
|
||||||
errs_filename = (
|
errs_filename = (
|
||||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
|
# we compute CER for aishell dataset.
|
||||||
|
results_char = []
|
||||||
|
for res in results:
|
||||||
|
results_char.append((list("".join(res[0])), list("".join(res[1]))))
|
||||||
with open(errs_filename, "w") as f:
|
with open(errs_filename, "w") as f:
|
||||||
wer = write_error_stats(
|
wer = write_error_stats(
|
||||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
f, f"{test_set_name}-{key}", results_char, enable_log=True
|
||||||
)
|
)
|
||||||
test_set_wers[key] = wer
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
@ -353,11 +349,11 @@ def save_results(
|
|||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tCER", file=f)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
print("{}\t{}".format(key, val), file=f)
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
|
||||||
note = "\tbest for {}".format(test_set_name)
|
note = "\tbest for {}".format(test_set_name)
|
||||||
for key, val in test_set_wers:
|
for key, val in test_set_wers:
|
||||||
s += "{}\t{}{}\n".format(key, val, note)
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
@ -368,9 +364,10 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
@ -381,6 +378,9 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if params.decoding_method == "beam_search":
|
if params.decoding_method == "beam_search":
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
@ -391,12 +391,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
lexicon = Lexicon(params.lang_dir)
|
||||||
sp.load(params.bpe_model)
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
|
lexicon=lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
params.vocab_size = sp.get_piece_size()
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
@ -422,23 +424,19 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
aishell = AishellAsrDataModule(args)
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_sets = ["test"]
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_dls = [test_dl]
|
||||||
|
|
||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
dl=test_dl,
|
dl=test_dl,
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
lexicon=lexicon,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -20,13 +20,14 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""This class implements the stateless decoder from the following paper:
|
"""This class modifies the stateless decoder from the following paper:
|
||||||
|
|
||||||
RNN-transducer with stateless prediction network
|
RNN-transducer with stateless prediction network
|
||||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||||
|
|
||||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||||
network.
|
network. Different from the above paper, it adds an extra Conv1d
|
||||||
|
right after the embedding layer.
|
||||||
|
|
||||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||||
"""
|
"""
|
||||||
|
@ -104,6 +104,14 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
|
||||||
# parameters for decoder
|
|
||||||
"context_size": 2, # tri-gram
|
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
@ -48,7 +47,7 @@ class Joiner(nn.Module):
|
|||||||
# Now decoder_out is (N, 1, U, C)
|
# Now decoder_out is (N, 1, U, C)
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
logit = F.relu(logit)
|
logit = torch.tanh(logit)
|
||||||
|
|
||||||
output = self.output_linear(logit)
|
output = self.output_linear(logit)
|
||||||
|
|
||||||
|
@ -110,6 +110,22 @@ def get_parser():
|
|||||||
help="Used only when --method is beam_search",
|
help="Used only when --method is beam_search",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="""Maximum number of symbols per frame. Used only when
|
||||||
|
--method is greedy_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
|
||||||
# parameters for decoder
|
|
||||||
"context_size": 2, # tri-gram
|
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -279,7 +291,11 @@ def main():
|
|||||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.method == "greedy_search":
|
if params.method == "greedy_search":
|
||||||
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
elif params.method == "beam_search":
|
elif params.method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang
|
# Wei Kang
|
||||||
# Mingshuang Luo)
|
# Mingshuang Luo)
|
||||||
|
# Copyright 2021 (Pingfeng Luo)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -120,6 +121,14 @@ def get_parser():
|
|||||||
help="The lr_factor for Noam optimizer",
|
help="The lr_factor for Noam optimizer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
|
"2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -161,15 +170,10 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
- subsampling_factor: The subsampling factor for the model.
|
- subsampling_factor: The subsampling factor for the model.
|
||||||
|
|
||||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
|
||||||
input features.
|
|
||||||
|
|
||||||
- attention_dim: Hidden dim for multi-head attention model.
|
- attention_dim: Hidden dim for multi-head attention model.
|
||||||
|
|
||||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||||
|
|
||||||
- weight_decay: The weight_decay for the optimizer.
|
|
||||||
|
|
||||||
- warm_step: The warm_step for Noam optimizer.
|
- warm_step: The warm_step for Noam optimizer.
|
||||||
"""
|
"""
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
@ -191,11 +195,7 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
|
||||||
# parameters for decoder
|
|
||||||
"context_size": 2, # tri-gram
|
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"weight_decay": 1e-6,
|
|
||||||
"warm_step": 80000, # For the 100h subset, use 8k
|
"warm_step": 80000, # For the 100h subset, use 8k
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
@ -215,7 +215,6 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -556,11 +555,10 @@ def run(rank, world_size, args):
|
|||||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
device=device,
|
device=device,
|
||||||
sos_token="<sos/eos>",
|
oov='<unk>',
|
||||||
eos_token="<sos/eos>",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0]
|
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
@ -584,7 +582,6 @@ def run(rank, world_size, args):
|
|||||||
model_size=params.attention_dim,
|
model_size=params.attention_dim,
|
||||||
factor=params.lr_factor,
|
factor=params.lr_factor,
|
||||||
warm_step=params.warm_step,
|
warm_step=params.warm_step,
|
||||||
weight_decay=params.weight_decay,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if checkpoints and "optimizer" in checkpoints:
|
if checkpoints and "optimizer" in checkpoints:
|
||||||
@ -611,8 +608,7 @@ def run(rank, world_size, args):
|
|||||||
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
|
||||||
|
|
||||||
train_dl = aishell.train_dataloaders(train_cuts)
|
train_dl = aishell.train_dataloaders(train_cuts)
|
||||||
|
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
|
||||||
valid_dl = aishell.valid_dataloaders(aishell.dev_cuts())
|
|
||||||
|
|
||||||
for epoch in range(params.start_epoch, params.num_epochs):
|
for epoch in range(params.start_epoch, params.num_epochs):
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
vgg_frontend: bool = False,
|
vgg_frontend: bool = False,
|
||||||
use_feat_batchnorm: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
|
|||||||
If True, use pre-layer norm; False to use post-layer norm.
|
If True, use pre-layer norm; False to use post-layer norm.
|
||||||
vgg_frontend:
|
vgg_frontend:
|
||||||
True to use vgg style frontend for subsampling.
|
True to use vgg style frontend for subsampling.
|
||||||
use_feat_batchnorm:
|
|
||||||
True to use batchnorm for the input layer.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_feat_batchnorm = use_feat_batchnorm
|
|
||||||
if use_feat_batchnorm:
|
|
||||||
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
|
||||||
|
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.output_dim = output_dim
|
self.output_dim = output_dim
|
||||||
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
|
|||||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
of frames in `logits` before padding.
|
of frames in `logits` before padding.
|
||||||
"""
|
"""
|
||||||
if self.use_feat_batchnorm:
|
|
||||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
|
||||||
x = self.feat_batchnorm(x)
|
|
||||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
|
||||||
|
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x = self.encoder_pos(x)
|
x = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user