Update RESULTS using vocab size 500, att rate 0.8

Fangjun Kuang 2021-11-09 20:46:11 +08:00
parent 42b437bea6
commit 1e4920410f
5 changed files with 89 additions and 8 deletions

View File

@@ -1,6 +1,61 @@
 ## Results
 ### LibriSpeech BPE training results (Conformer-CTC)
+
+#### 2021-11-09
+
+The best WER, as of 2021-11-09, on the LibriSpeech test datasets is given below
+(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):
+
+|     | test-clean | test-other |
+|-----|------------|------------|
+| WER | 2.42       | 5.73       |
+
+The scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:
+
+| ngram_lm_scale | attention_scale |
+|----------------|-----------------|
+| 2.0            | 2.0             |
+
+To reproduce the above result, use the following commands for training:
+
+```
+cd egs/librispeech/ASR
+./prepare.sh
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./conformer_ctc/train.py \
+  --exp-dir conformer_ctc/exp_500_att0.8 \
+  --lang-dir data/lang_bpe_500 \
+  --att-rate 0.8 \
+  --full-libri 1 \
+  --max-duration 200 \
+  --concatenate-cuts 0 \
+  --world-size 4 \
+  --bucketing-sampler 1 \
+  --start-epoch 0 \
+  --num-epochs 80
+```
+
+and the following command for decoding:
+
+```
+./conformer_ctc/decode.py \
+  --exp-dir conformer_ctc/exp_500_att0.8 \
+  --lang-dir data/lang_bpe_500 \
+  --max-duration 30 \
+  --concatenate-cuts 0 \
+  --bucketing-sampler 1 \
+  --num-paths 1000 \
+  --epoch 77 \
+  --avg 55 \
+  --method attention-decoder \
+  --nbest-scale 0.5
+```
+
+You can find the pre-trained model by visiting
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>
+
 #### 2021-08-19
 (Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13
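
Note on `--epoch 77 --avg 55`: decoding averages the model weights of the last 55 epoch checkpoints ending at epoch 77, i.e. epochs 23 through 77. A minimal sketch of that averaging, assuming checkpoints named `epoch-<i>.pt` that store the model under a `"model"` key (icefall's `average_checkpoints` helper does the real work; this loop is only illustrative):

```
import torch

def average_models(filenames):
    # Element-wise average of the "model" state_dicts in the given checkpoints.
    avg = None
    for name in filenames:
        state = torch.load(name, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}

epoch, num_avg = 77, 55
filenames = [
    f"conformer_ctc/exp_500_att0.8/epoch-{i}.pt"
    for i in range(epoch - num_avg + 1, epoch + 1)  # epochs 23..77
]
# averaged = average_models(filenames)
```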

View File

@@ -601,6 +601,11 @@ def main():
             G.labels[G.labels >= first_word_disambig_id] = 0
             G = k2.Fsa.from_fsas([G]).to(device)
             G = k2.arc_sort(G)
+
+            # Save a dummy value so that it can be loaded in C++.
+            # See https://github.com/pytorch/pytorch/issues/67902
+            # for why we need to do this.
+            G["dummy"] = 1
             torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
         else:
             logging.info("Loading pre-compiled G_4_gram.pt")

View File

@@ -81,7 +81,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=35,
+        default=78,
         help="Number of epochs to train.",
     )
@@ -108,13 +108,22 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_bpe_5000",
+        default="data/lang_bpe_500",
         help="""The lang dir
         It contains language related input files such as
         "lexicon.txt"
         """,
     )

+    parser.add_argument(
+        "--att-rate",
+        type=float,
+        default=0.8,
+        help="""The attention rate.
+        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
+        """,
+    )
+
     return parser
@@ -198,7 +207,6 @@ def get_params() -> AttributeDict:
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
-            "att_rate": 0.7,
             # parameters for Noam
             "weight_decay": 1e-6,
             "lr_factor": 5.0,

View File

@@ -311,7 +311,7 @@
         self,
         memory: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
-        token_ids: List[List[int]],
+        token_ids: List[torch.Tensor],
         sos_id: int,
         eos_id: int,
     ) -> torch.Tensor:
@@ -334,6 +334,11 @@
         """
         # The common part between this function and decoder_forward could be
        # extracted as a separate function.
+        if isinstance(token_ids[0], torch.Tensor):
+            # This branch is executed by torchscript in C++.
+            # See https://github.com/k2-fsa/k2/pull/870
+            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
+            token_ids = [tolist(t) for t in token_ids]

         ys_in = add_sos(token_ids, sos_id=sos_id)
         ys_in = [torch.tensor(y) for y in ys_in]
@@ -660,7 +665,7 @@
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout)
         # not doing: self.pe = None because of errors thrown by torchscript
-        self.pe = torch.zeros(0, 0, dtype=torch.float32)
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.
@@ -1000,3 +1005,8 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
     with EOS ID.
     """
     return [utt + [eos_id] for utt in token_ids]
+
+
+def tolist(t: torch.Tensor) -> List[int]:
+    """Used by jit"""
+    return torch.jit.annotate(List[int], t.tolist())
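
The `torch.jit.annotate` call above is the documented way to give `Tensor.tolist()` a concrete element type under TorchScript, whose return type cannot be inferred statically. A standalone sketch of the pattern (independent of the transformer code; the `convert` function is only for illustration):

```
import torch
from typing import List


def tolist(t: torch.Tensor) -> List[int]:
    # Tell TorchScript that t is a 1-D integer tensor, so tolist() is List[int].
    return torch.jit.annotate(List[int], t.tolist())


@torch.jit.script
def convert(token_ids: List[torch.Tensor]) -> List[List[int]]:
    # Scripted (C++) callers pass token IDs as 1-D int tensors;
    # convert them back to plain Python lists of ints.
    return [tolist(t) for t in token_ids]


print(convert([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]))
# -> [[1, 2, 3], [4, 5]]
```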

View File

@@ -224,6 +224,7 @@ class Nbest(object):
         else:
             word_seq = lattice.aux_labels.index(path)
             word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
+        word_seq = word_seq.remove_values_leq(0)

         # Each utterance has `num_paths` paths but some of them transduces
         # to the same word sequence, so we need to remove repeated word
@@ -363,7 +364,7 @@
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores
+        saved_scores = self.fsa.scores.clone()

         # The `scores` of every arc consists of `am_scores` and `lm_scores`
         self.fsa.scores = self.fsa.scores - self.fsa.lm_scores
@@ -390,10 +391,10 @@
         Return a ragged tensor with 2 axes [utt][path_scores].
         Its dtype is torch.float64.
         """
-        saved_scores = self.fsa.scores
+        saved_scores = self.fsa.scores.clone()

         # The `scores` of every arc consists of `am_scores` and `lm_scores`
-        self.fsa.scores = self.fsa.lm_scores
+        self.fsa.scores = self.fsa.lm_scores.clone()

         lm_scores = self.fsa.get_tot_scores(
             use_double_scores=True, log_semiring=False
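
The `.clone()` calls guard against aliasing: saving a bare reference to a tensor that is later overwritten in place also changes the "backup". A minimal PyTorch illustration of the failure mode (plain tensors here, not k2 FSAs):

```
import torch

scores = torch.tensor([1.0, 2.0, 3.0])
saved = scores          # just another reference to the same storage
scores[:] = 0.0         # in-place update, as an attribute setter may do
print(saved)            # tensor([0., 0., 0.]) -- the "backup" changed too

scores = torch.tensor([1.0, 2.0, 3.0])
saved = scores.clone()  # an independent copy
scores[:] = 0.0
print(saved)            # tensor([1., 2., 3.]) -- the backup survives
```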
@@ -870,6 +871,7 @@ def rescore_with_attention_decoder(
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

@@ -877,6 +879,7 @@ def rescore_with_attention_decoder(
         attention_scale_list = [0.01, 0.05, 0.08]
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         attention_scale_list = [attention_scale]
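
The two extended lists widen the grid that `rescore_with_attention_decoder` searches: each (ngram_lm_scale, attention_scale) pair ranks the n-best paths by a linearly combined score, and the pair giving the lowest WER is reported (2.0/2.0 in the table above). A minimal sketch of that combination for a single utterance; the score values below are made up:

```
import torch

# Hypothetical per-path scores for one utterance's n-best list.
am_scores = torch.tensor([-310.2, -312.9, -309.8])      # acoustic
ngram_lm_scores = torch.tensor([-41.7, -39.2, -44.0])   # n-gram LM
attention_scores = torch.tensor([-28.4, -25.1, -30.6])  # attention decoder

best_path = {}
for n_scale in [0.5, 1.0, 2.0, 3.0, 5.0]:      # subset of ngram_lm_scale_list
    for a_scale in [0.5, 1.0, 2.0, 3.0, 5.0]:  # subset of attention_scale_list
        tot = am_scores + n_scale * ngram_lm_scores + a_scale * attention_scores
        best_path[(n_scale, a_scale)] = int(tot.argmax())
```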