Add results

This commit is contained in:
pkufool 2023-11-21 17:39:01 +08:00
parent e9ccc0b073
commit 019fa0396e
4 changed files with 171 additions and 40 deletions

View File

@ -0,0 +1,6 @@
# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context
Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105).
See [RESULTS](./RESULTS.md) for the results for icefall recipes.

View File

@ -0,0 +1,111 @@
# Results
## zipformer (zipformer + pruned stateless transducer)
See <https://github.com/k2-fsa/icefall/pull/1261> for more details.
[zipformer](./zipformer)
### Non-streaming
#### Training on normalized text, i.e. Upper case without punctuation
##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M
You can find a pretrained model, training logs at:
<https://www.modelscope.cn/models/pkufool/icefall-asr-zipformer-libriheavy-20230926/summary>
Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set),
exp_small_subset(small set).
Results of models:
| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment |
|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 |
| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 |
| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 |
| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 |
| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 |
| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python ./zipformer/train.py \
--world-size 4 \
--master-port 12365 \
--exp-dir zipformer/exp \
--num-epochs 60 \ # 16 for large; 90 for small
--lr-hours 15000 \ # 20000 for large; 5000 for small
--use-fp16 1 \
--start-epoch 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--max-duration 1000 \
--subset medium
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search; do
./zipformer/decode.py \
--epoch 16 \
--avg 3 \
--exp-dir zipformer/exp \
--max-duration 1000 \
--causal 0 \
--decoding-method $m
done
```
#### Training on full formatted text, i.e. with casing and punctuation
##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M
You can find a pretrained model, training logs at:
<https://www.modelscope.cn/models/pkufool/icefall-asr-zipformer-libriheavy-punc-20230830/summary>
Note: The repository above contains three models trained on different subset of libriheavy exp(large set), exp_medium_subset(medium set),
exp_small_subset(small set).
Results of models:
| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment |
|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 |
| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 |
| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 |
The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python ./zipformer/train.py \
--world-size 4 \
--master-port 12365 \
--exp-dir zipformer/exp \
--num-epochs 60 \ # 16 for large; 90 for small
--lr-hours 15000 \ # 20000 for large; 10000 for small
--use-fp16 1 \
--train-with-punctuation 1 \
--start-epoch 1 \
--bpe-model data/lang_punc_bpe_756/bpe.model \
--max-duration 1000 \
--subset medium
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search; do
./zipformer/decode.py \
--epoch 16 \
--avg 3 \
--exp-dir zipformer/exp \
--max-duration 1000 \
--causal 0 \
--decoding-method $m
done
```

View File

@ -208,6 +208,7 @@ fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Combine features for medium and large subsets."
for subset in medium large; do
log "Combining $subset subset."
if [ ! -f $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz")
lhotse combine $pieces $fbank_dir/libriheavy_cuts_${subset}.jsonl.gz
@ -264,3 +265,50 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
fi
done
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare language model for normalized text"
for subset in small medium large; do
if [ ! -f $manifests_dir/texts_${subset} ]; then
gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
| jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
| ./local/norm_text.py > $manifests_dir/texts_${subset}
fi
done
mkdir -p data/lm
if [ ! -f data/lm/text ]; then
cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text
fi
(echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
> data/lm/words.txt
cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
| awk '{print $1" "NR+3}' >> data/lm/words.txt
num_lines=$(< data/lm/words.txt wc -l)
(echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
>> data/lm/words.txt
# Train LM on transcripts
if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
python3 ./shared/make_kn_lm.py \
-ngram-order 3 \
-text data/lm/text \
-lm data/lm/3-gram.unpruned.arpa
fi
# We assume you have install kaldilm, if not, please install
# it using: pip install kaldilm
if [ ! -f data/lm/G_3_gram_char.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table=data/lm/words.txt \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
fi
fi

View File

@ -81,17 +81,6 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./zipformer/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./zipformer/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
@ -111,7 +100,6 @@ from asr_datamodule import LibriHeavyAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
@ -237,7 +225,7 @@ def get_parser():
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search_nbest,
and fast_beam_search_nbest_oracle
""",
)
@ -247,7 +235,7 @@ def get_parser():
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search, fast_beam_search_nbest,
and fast_beam_search_nbest_oracle""",
)
@ -256,7 +244,7 @@ def get_parser():
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
fast_beam_search, fast_beam_search_nbest,
and fast_beam_search_nbest_oracle""",
)
@ -280,7 +268,7 @@ def get_parser():
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -289,7 +277,7 @@ def get_parser():
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -318,7 +306,6 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -342,12 +329,10 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
fast_beam_search_nbest_oracle.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -388,20 +373,6 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
@ -493,7 +464,6 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -507,8 +477,6 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
@ -548,7 +516,6 @@ def decode_dataset(
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
@ -815,7 +782,6 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)