Update RESULTS using vocab size 500, att rate 0.8

parent 42b437bea6
commit 1e4920410f

@@ -1,6 +1,61 @@
## Results

### LibriSpeech BPE training results (Conformer-CTC)

#### 2021-11-09

The best WER, as of 2021-11-09, for the LibriSpeech test datasets is given below
(using HLG decoding + n-gram LM rescoring + attention decoder rescoring):

|     | test-clean | test-other |
|-----|------------|------------|
| WER | 2.42       | 5.73       |

Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are:

| ngram_lm_scale | attention_scale |
|----------------|-----------------|
| 2.0            | 2.0             |
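
For reference, these two scales enter the final decision as weights on the n-gram LM and attention-decoder scores of each candidate path. A minimal sketch of the combination (the tensor values and variable names are illustrative, not icefall's exact code):

```
import torch

# Illustrative per-path scores for one utterance; in icefall they come from
# the decoding lattice (acoustic + n-gram LM) and from the attention decoder.
am_scores = torch.tensor([-120.3, -118.9])
ngram_lm_scores = torch.tensor([-35.1, -34.2])
attention_scores = torch.tensor([-20.4, -19.8])

ngram_lm_scale = 2.0   # best value reported above
attention_scale = 2.0  # best value reported above

# Total score used to rank the candidate paths.
tot_scores = (
    am_scores
    + ngram_lm_scale * ngram_lm_scores
    + attention_scale * attention_scores
)
best_path = int(tot_scores.argmax())
```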

To reproduce the above result, use the following commands for training:

```
cd egs/librispeech/ASR
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./conformer_ctc/train.py \
  --exp-dir conformer_ctc/exp_500_att0.8 \
  --lang-dir data/lang_bpe_500 \
  --att-rate 0.8 \
  --full-libri 1 \
  --max-duration 200 \
  --concatenate-cuts 0 \
  --world-size 4 \
  --bucketing-sampler 1 \
  --start-epoch 0 \
  --num-epochs 80
```
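
The `--att-rate 0.8` option sets how the CTC loss and the attention-decoder loss are weighted; as the option's help text states, the total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss. A minimal sketch of that combination (the loss values are illustrative):

```
import torch

att_rate = 0.8  # value used for the result above

# Illustrative loss values; in train.py they come from the CTC head and the
# attention decoder of the conformer model.
ctc_loss = torch.tensor(45.7)
att_loss = torch.tensor(32.1)

# Total loss as described by the --att-rate help text.
loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
```
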
and the following command for decoding:

```
./conformer_ctc/decode.py \
  --exp-dir conformer_ctc/exp_500_att0.8 \
  --lang-dir data/lang_bpe_500 \
  --max-duration 30 \
  --concatenate-cuts 0 \
  --bucketing-sampler 1 \
  --num-paths 1000 \
  --epoch 77 \
  --avg 55 \
  --method attention-decoder \
  --nbest-scale 0.5
```
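
Here `--num-paths 1000` and `--nbest-scale 0.5` control how the n-best list is drawn from the decoding lattice: the lattice scores are scaled down before paths are sampled, which makes the sampled paths more diverse. A rough sketch of that idea, assuming a k2 lattice (variable names are illustrative, not the exact decode.py code):

```
import k2


def sample_nbest_paths(lattice: k2.Fsa, num_paths: int = 1000, nbest_scale: float = 0.5):
    """Sample an n-best list of paths from a decoding lattice.

    Scaling the scores down (nbest_scale < 1) flattens the score
    distribution, so random_paths() returns a more diverse set of paths.
    """
    saved_scores = lattice.scores.clone()
    lattice.scores = saved_scores * nbest_scale
    # One ragged list of paths (as arc indexes) per utterance.
    paths = k2.random_paths(lattice, use_double_scores=True, num_paths=num_paths)
    lattice.scores = saved_scores  # restore the original scores
    return paths
```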

You can find the pre-trained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09>

#### 2021-08-19

(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13

@@ -601,6 +601,11 @@ def main():
        G.labels[G.labels >= first_word_disambig_id] = 0
        G = k2.Fsa.from_fsas([G]).to(device)
        G = k2.arc_sort(G)
        # Save a dummy value so that it can be loaded in C++.
        # See https://github.com/pytorch/pytorch/issues/67902
        # for why we need to do this.
        G["dummy"] = 1

        torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
    else:
        logging.info("Loading pre-compiled G_4_gram.pt")
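
For reference, a minimal sketch of the Python-side round trip for the saved 4-gram LM (the path below is illustrative; in the script it is params.lm_dir / "G_4_gram.pt"). The "dummy" attribute exists only so that the same archive can also be loaded from C++ through torchscript, per the linked PyTorch issue:

```
import torch
import k2

# Load the dict saved above with torch.save(G.as_dict(), ...) and rebuild the FSA.
d = torch.load("data/lm/G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d)
```
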
@@ -81,7 +81,7 @@ def get_parser():
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=35,
        default=78,
        help="Number of epochs to train.",
    )

@@ -108,13 +108,22 @@ def get_parser():
    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_5000",
        default="data/lang_bpe_500",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--att-rate",
        type=float,
        default=0.8,
        help="""The attention rate.
        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
        """,
    )

    return parser

@@ -198,7 +207,6 @@ def get_params() -> AttributeDict:
            "beam_size": 10,
            "reduction": "sum",
            "use_double_scores": True,
            "att_rate": 0.7,
            # parameters for Noam
            "weight_decay": 1e-6,
            "lr_factor": 5.0,

@@ -311,7 +311,7 @@ class Transformer(nn.Module):
        self,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
        token_ids: List[List[int]],
        token_ids: List[torch.Tensor],
        sos_id: int,
        eos_id: int,
    ) -> torch.Tensor:

@@ -334,6 +334,11 @@ class Transformer(nn.Module):
        """
        # The common part between this function and decoder_forward could be
        # extracted as a separate function.
        if isinstance(token_ids[0], torch.Tensor):
            # This branch is executed by torchscript in C++.
            # See https://github.com/k2-fsa/k2/pull/870
            # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
            token_ids = [tolist(t) for t in token_ids]

        ys_in = add_sos(token_ids, sos_id=sos_id)
        ys_in = [torch.tensor(y) for y in ys_in]

@@ -660,7 +665,7 @@ class PositionalEncoding(nn.Module):
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        # not doing: self.pe = None because of errors thrown by torchscript
        self.pe = torch.zeros(0, 0, dtype=torch.float32)
        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Extend the time t in the positional encoding if required.

@@ -1000,3 +1005,8 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
      with EOS ID.
    """
    return [utt + [eos_id] for utt in token_ids]


def tolist(t: torch.Tensor) -> List[int]:
    """Used by jit"""
    return torch.jit.annotate(List[int], t.tolist())
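
The `tolist` helper exists because TorchScript needs an explicit element type for `Tensor.tolist()`; `torch.jit.annotate` pins the return type to `List[int]` so the function can be scripted and reused from C++. A quick standalone check (not part of the commit) that such a helper scripts cleanly:

```
from typing import List

import torch


def tolist(t: torch.Tensor) -> List[int]:
    """Convert a 1-D tensor to a Python list of ints under TorchScript."""
    return torch.jit.annotate(List[int], t.tolist())


scripted = torch.jit.script(tolist)
print(scripted(torch.tensor([3, 1, 4])))  # [3, 1, 4]
```
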
@@ -224,6 +224,7 @@ class Nbest(object):
        else:
            word_seq = lattice.aux_labels.index(path)
            word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
        word_seq = word_seq.remove_values_leq(0)

        # Each utterance has `num_paths` paths but some of them transduces
        # to the same word sequence, so we need to remove repeated word

@@ -363,7 +364,7 @@ class Nbest(object):
        Return a ragged tensor with 2 axes [utt][path_scores].
        Its dtype is torch.float64.
        """
        saved_scores = self.fsa.scores
        saved_scores = self.fsa.scores.clone()

        # The `scores` of every arc consists of `am_scores` and `lm_scores`
        self.fsa.scores = self.fsa.scores - self.fsa.lm_scores

@@ -390,10 +391,10 @@ class Nbest(object):
        Return a ragged tensor with 2 axes [utt][path_scores].
        Its dtype is torch.float64.
        """
        saved_scores = self.fsa.scores
        saved_scores = self.fsa.scores.clone()

        # The `scores` of every arc consists of `am_scores` and `lm_scores`
        self.fsa.scores = self.fsa.lm_scores
        self.fsa.scores = self.fsa.lm_scores.clone()

        lm_scores = self.fsa.get_tot_scores(
            use_double_scores=True, log_semiring=False
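
The `.clone()` changes above guard against aliasing: if a saved reference shares storage with `self.fsa.scores`, a later in-place update to the scores would silently change the "saved" values as well. A generic illustration with plain PyTorch tensors (not k2-specific code):

```
import torch

scores = torch.tensor([0.5, -1.2, 3.4])

saved_ref = scores           # shares storage with `scores`
saved_copy = scores.clone()  # independent copy

scores[:] = 0.0              # in-place update

print(saved_ref)   # tensor([0., 0., 0.])  -- changed together with `scores`
print(saved_copy)  # tensor([0.5000, -1.2000, 3.4000])  -- preserved
```
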
@@ -870,6 +871,7 @@ def rescore_with_attention_decoder(
        ngram_lm_scale_list = [0.01, 0.05, 0.08]
        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
    else:
        ngram_lm_scale_list = [ngram_lm_scale]

@@ -877,6 +879,7 @@ def rescore_with_attention_decoder(
        attention_scale_list = [0.01, 0.05, 0.08]
        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
    else:
        attention_scale_list = [attention_scale]