This commit finalizes the recipe (hopefully)

jinzr 2023-09-02 01:33:26 +08:00
parent aadb4507bf
commit 2cb0092b09
9 changed files with 107 additions and 171 deletions

View File

@@ -0,0 +1,23 @@
# Introduction
This recipe includes scripts for training a Zipformer model using multiple Chinese datasets.
# Included Training Sets
1. THCHS-30
2. Aishell-{1,2,4}
3. ST-CMDS
4. Primewords
5. MagicData
6. Aidatatang_200zh
7. AliMeeting
8. WenetSpeech
9. KeSpeech-ASR
# Included Test Sets
1. Aishell-{1,2,4}
2. Aidatatang_200zh
3. AliMeeting
4. MagicData
5. KeSpeech-ASR
6. WenetSpeech
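# Quick Start
A typical run of this recipe prepares the data and then launches training. The sketch below is illustrative; see `prepare.sh` and the training command in `RESULTS.md` for the authoritative stages and flags:
```
# Prepare all corpora and extract fbank features
./prepare.sh

# Train the non-streaming Zipformer on the combined training sets
# (these flags match the reported results)
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 23 \
  --use-fp16 1 \
  --max-duration 500 \
  --num-workers 8
```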

View File

@@ -0,0 +1,30 @@
## Results
### Multi Chinese datasets char-based training results (non-streaming) on zipformer model
This is the [pull request](https://github.com/k2-fsa/icefall/pull/1130) in icefall.
#### Non-streaming
Best results (number of parameters: ~68M):
The training command:
```
./zipformer/train.py \
--world-size 4 \
--num-epochs 23 \
--use-fp16 1 \
--max-duration 500 \
--num-workers 8
```
Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch, using greedy search and a BPE model (vocabulary size 2000, with byte fallback enabled).
| Dataset | Split | Zipformer CER (%) |
|------------------|--------------|-------------------|
| aidatatang_200zh | dev | 3.2 |
| aidatatang_200zh | test | 3.67 |
| alimeeting | eval | 23.15 |
| alimeeting | test | 24.78 |
| aishell-1 | dev | 2.91 |
| aishell-1 | test | 3.04 |
| aishell-2 | dev | 3.59 |
| aishell-2 | test | 4.03 |
| aishell-4 | test | 15.68 |
| magicdata | dev | 3.68 |
| magicdata | test | 3.12 |
| kespeech-asr | dev phase1 | 6.69 |
| kespeech-asr | dev phase2 | 3.19 |
| kespeech-asr | test | 8.01 |
| WenetSpeech | dev | 9.32 |
| WenetSpeech | test meeting | 7.05 |
| WenetSpeech | test net | 8.78 |
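For reference, a tokenizer with these properties could be trained with sentencepiece roughly as follows (an illustrative sketch, not necessarily the recipe's exact invocation; `transcripts.txt` stands in for the combined training transcripts):
```
spm_train --input=transcripts.txt \
  --model_prefix=bpe_2000 \
  --vocab_size=2000 \
  --model_type=bpe \
  --byte_fallback=true
```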
The pre-trained model is available at: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
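To reproduce the greedy-search numbers above, a decoding command along these lines should work (the flags shown are the usual ones for icefall zipformer recipes; check `./zipformer/decode.py --help` for the exact set, and note that `--avg 1` is an assumption here):
```
./zipformer/decode.py \
  --epoch 20 \
  --avg 1 \
  --exp-dir zipformer/exp \
  --max-duration 600 \
  --decoding-method greedy_search
```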

View File

@@ -229,10 +229,12 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
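# These symlinks reuse the fbank manifests already computed by the wenetspeech recipe, so the features are not recomputed here.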
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech
cd ../..
else
log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"

View File

@@ -52,7 +52,7 @@ class _SeedWorkers:
fix_random_seed(self.seed + worker_id)
class LibriSpeechAsrDataModule:
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
@@ -82,20 +82,6 @@ class LibriSpeechAsrDataModule:
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="""Used only when --mini-libri is False.When enabled,
use 960h LibriSpeech. Otherwise, use 100h subset.""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
@@ -400,76 +386,3 @@ class LibriSpeechAsrDataModule:
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
logging.info("About to get test-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
)

View File

@@ -87,8 +87,8 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_model
from asr_datamodule import AsrDataModule
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@@ -598,7 +598,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
@@ -811,7 +811,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
librispeech = AsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()

View File

@@ -105,7 +105,7 @@ import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@@ -609,7 +609,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -784,15 +784,9 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
# test_clean_cuts = librispeech.test_clean_cuts()
# test_other_cuts = librispeech.test_other_cuts()
# test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
def remove_short_utt(c: Cut):
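# Keep only cuts that still yield at least one encoder frame after the (roughly 4x) convolutional subsampling of the zipformer frontend.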
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
@@ -805,7 +799,7 @@ def main():
test_sets = test_sets_cuts.keys()
test_dl = [
librispeech.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
for cuts_name in test_sets
]

View File

@@ -76,12 +76,11 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from onnx_pretrained import greedy_search, OnnxModel
from asr_datamodule import AsrDataModule
from k2 import SymbolTable
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
from k2 import SymbolTable
def get_parser():
@@ -263,7 +262,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
@@ -290,7 +289,7 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
librispeech = AsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
@@ -303,7 +302,9 @@
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration

View File

@@ -40,7 +40,7 @@ import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import AsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@@ -51,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
@@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
@@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
@@ -404,9 +398,7 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
@@ -494,9 +486,7 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
@@ -517,9 +507,7 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
@@ -577,9 +565,7 @@ def decode_dataset(
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@@ -649,9 +635,7 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
@@ -684,8 +668,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@@ -703,7 +686,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@@ -718,9 +701,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
@@ -760,9 +741,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -789,9 +770,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -846,7 +827,7 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
librispeech = AsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()

View File

@@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--full-libri 1 \
--max-duration 1000
# For streaming model training:
@@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--use-fp16 1 \
--exp-dir zipformer/exp \
--causal 1 \
--full-libri 1 \
--max-duration 1000
It supports training with:
@@ -65,7 +63,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import AsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -1173,14 +1171,10 @@ def run(rank, world_size, args):
if params.inf_check:
register_inf_check_hooks(model)
librispeech = LibriSpeechAsrDataModule(args)
data_module = AsrDataModule(args)
multi_dataset = MultiDataset(args.manifest_dir)
train_cuts = multi_dataset.train_cuts()
# train_cuts = librispeech.train_clean_100_cuts()
# if params.full_libri:
# train_cuts += librispeech.train_clean_360_cuts()
# train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1228,23 +1222,21 @@ def run(rank, world_size, args):
else:
sampler_state_dict = None
train_dl = librispeech.train_dataloaders(
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
# valid_cuts = librispeech.dev_clean_cuts()
# valid_cuts += librispeech.dev_other_cuts()
valid_cuts = multi_dataset.dev_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
valid_dl = data_module.valid_dataloaders(valid_cuts)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@@ -1374,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)