mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
This commit finalizes the recipe (hopefully)
This commit is contained in:
parent
aadb4507bf
commit
2cb0092b09
23
egs/multi_zh-hans/ASR/README.md
Normal file
23
egs/multi_zh-hans/ASR/README.md
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
|
||||||
|
# Introduction
|
||||||
|
|
||||||
|
This recipe includes scripts for training a Zipformer model using multiple Chinese datasets.
|
||||||
|
|
||||||
|
# Included Training Sets
|
||||||
|
1. THCHS-30
|
||||||
|
2. AISHELL-{1,2,4}
|
||||||
|
3. ST-CMDS
|
||||||
|
4. Primewords
|
||||||
|
5. MagicData
|
||||||
|
6. Aidatatang_200zh
|
||||||
|
7. AliMeeting
|
||||||
|
8. WenetSpeech
|
||||||
|
9. KeSpeech-ASR
|
||||||
|
|
||||||
|
# Included Test Sets
|
||||||
|
1. AISHELL-{1,2,4}
|
||||||
|
2. Aidatatang_200zh
|
||||||
|
3. AliMeeting
|
||||||
|
4. MagicData
|
||||||
|
5. KeSpeech-ASR
|
||||||
|
6. WenetSpeech
|
30
egs/multi_zh-hans/ASR/RESULTS.md
Normal file
30
egs/multi_zh-hans/ASR/RESULTS.md
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
## Results
|
||||||
|
|
||||||
|
### WenetSpeech char-based training results (Non-streaming and streaming) on zipformer model
|
||||||
|
|
||||||
|
This is the [pull request](https://github.com/k2-fsa/icefall/pull/1130) in icefall.
|
||||||
|
|
||||||
|
#### Non-streaming
|
||||||
|
|
||||||
|
Best results (number of parameters: ~68M):
|
||||||
|
|
||||||
|
The training command:
|
||||||
|
|
||||||
|
```
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 23 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--max-duration 500 \
|
||||||
|
--num-workers 8
|
||||||
|
```
|
||||||
|
|
||||||
|
The Character Error Rates (CERs) listed below were produced by the checkpoint of the 20th epoch using greedy search and a BPE model (2000 tokens, byte fallback enabled).
|
||||||
|
|
||||||
|
| Datasets | aidatatang_200zh | aidatatang_200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|
||||||
|
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
|
||||||
|
| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
|
||||||
|
| | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |
|
||||||
|
|
||||||
|
|
||||||
|
The pre-trained model is available here: https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2
|
@ -229,10 +229,12 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
|||||||
cd data/fbank
|
cd data/fbank
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_M.jsonl.gz) .
|
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_S.jsonl.gz) .
|
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
|
||||||
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
|
||||||
|
|
||||||
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/L_split_1000) .
|
||||||
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/*.lca) .
|
||||||
|
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/) ./wenetspeech
|
||||||
cd ../..
|
cd ../..
|
||||||
else
|
else
|
||||||
log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
|
log "Abort! Please run ../../wenetspeech/ASR/prepare.sh"
|
||||||
|
@ -52,7 +52,7 @@ class _SeedWorkers:
|
|||||||
fix_random_seed(self.seed + worker_id)
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
class LibriSpeechAsrDataModule:
|
class AsrDataModule:
|
||||||
"""
|
"""
|
||||||
DataModule for k2 ASR experiments.
|
DataModule for k2 ASR experiments.
|
||||||
It assumes there is always one train and valid dataloader,
|
It assumes there is always one train and valid dataloader,
|
||||||
@ -82,20 +82,6 @@ class LibriSpeechAsrDataModule:
|
|||||||
"effective batch sizes, sampling strategies, applied data "
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
"augmentations, etc.",
|
"augmentations, etc.",
|
||||||
)
|
)
|
||||||
group.add_argument(
|
|
||||||
"--full-libri",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="""Used only when --mini-libri is False.When enabled,
|
|
||||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
|
||||||
)
|
|
||||||
group.add_argument(
|
|
||||||
"--mini-libri",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="True for mini librispeech",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--manifest-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
@ -400,76 +386,3 @@ class LibriSpeechAsrDataModule:
|
|||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_clean_5_cuts(self) -> CutSet:
|
|
||||||
logging.info("mini_librispeech: About to get train-clean-5 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_clean_100_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get train-clean-100 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_clean_360_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get train-clean-360 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_other_500_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get train-other-500 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def train_all_shuf_cuts(self) -> CutSet:
|
|
||||||
logging.info(
|
|
||||||
"About to get the shuffled train-clean-100, \
|
|
||||||
train-clean-360 and train-other-500 cuts"
|
|
||||||
)
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def dev_clean_2_cuts(self) -> CutSet:
|
|
||||||
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def dev_clean_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get dev-clean cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def dev_other_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get dev-other cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_clean_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test-clean cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache()
|
|
||||||
def test_other_cuts(self) -> CutSet:
|
|
||||||
logging.info("About to get test-other cuts")
|
|
||||||
return load_manifest_lazy(
|
|
||||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
|
||||||
)
|
|
||||||
|
@ -87,8 +87,8 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -598,7 +598,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
args.lang_dir = Path(args.lang_dir)
|
||||||
@ -811,7 +811,7 @@ def main():
|
|||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = AsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
@ -105,7 +105,7 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
@ -609,7 +609,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -784,15 +784,9 @@ def main():
|
|||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
data_module = AsrDataModule(args)
|
||||||
multi_dataset = MultiDataset(args.manifest_dir)
|
multi_dataset = MultiDataset(args.manifest_dir)
|
||||||
|
|
||||||
# test_clean_cuts = librispeech.test_clean_cuts()
|
|
||||||
# test_other_cuts = librispeech.test_other_cuts()
|
|
||||||
|
|
||||||
# test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
|
||||||
# test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
|
||||||
|
|
||||||
def remove_short_utt(c: Cut):
|
def remove_short_utt(c: Cut):
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
if T <= 0:
|
if T <= 0:
|
||||||
@ -805,7 +799,7 @@ def main():
|
|||||||
|
|
||||||
test_sets = test_sets_cuts.keys()
|
test_sets = test_sets_cuts.keys()
|
||||||
test_dl = [
|
test_dl = [
|
||||||
librispeech.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
|
data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
|
||||||
for cuts_name in test_sets
|
for cuts_name in test_sets
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -76,12 +76,11 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
|
from k2 import SymbolTable
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
from k2 import SymbolTable
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -263,7 +262,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@ -290,7 +289,7 @@ def main():
|
|||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = AsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
@ -303,7 +302,9 @@ def main():
|
|||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
|
results, total_duration = decode_dataset(
|
||||||
|
dl=test_dl, model=model, token_table=token_table
|
||||||
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
elapsed_seconds = end_time - start_time
|
elapsed_seconds = end_time - start_time
|
||||||
rtf = elapsed_seconds / total_duration
|
rtf = elapsed_seconds / total_duration
|
||||||
|
@ -40,7 +40,7 @@ import k2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from decode_stream import DecodeStream
|
from decode_stream import DecodeStream
|
||||||
from kaldifeat import Fbank, FbankOptions
|
from kaldifeat import Fbank, FbankOptions
|
||||||
from lhotse import CutSet
|
from lhotse import CutSet
|
||||||
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
|
|||||||
)
|
)
|
||||||
batch_states.append(cached_embed_left_pad)
|
batch_states.append(cached_embed_left_pad)
|
||||||
|
|
||||||
processed_lens = torch.cat(
|
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
|
||||||
[state_list[i][-1] for i in range(batch_size)], dim=0
|
|
||||||
)
|
|
||||||
batch_states.append(processed_lens)
|
batch_states.append(processed_lens)
|
||||||
|
|
||||||
return batch_states
|
return batch_states
|
||||||
@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
|||||||
for layer in range(tot_num_layers):
|
for layer in range(tot_num_layers):
|
||||||
layer_offset = layer * 6
|
layer_offset = layer * 6
|
||||||
# cached_key: (left_context_len, batch_size, key_dim)
|
# cached_key: (left_context_len, batch_size, key_dim)
|
||||||
cached_key_list = batch_states[layer_offset].chunk(
|
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
|
||||||
chunks=batch_size, dim=1
|
|
||||||
)
|
|
||||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||||
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
||||||
chunks=batch_size, dim=1
|
chunks=batch_size, dim=1
|
||||||
@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
|||||||
cached_conv2_list[i],
|
cached_conv2_list[i],
|
||||||
]
|
]
|
||||||
|
|
||||||
cached_embed_left_pad_list = batch_states[-2].chunk(
|
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
|
||||||
chunks=batch_size, dim=0
|
|
||||||
)
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
state_list[i].append(cached_embed_left_pad_list[i])
|
state_list[i].append(cached_embed_left_pad_list[i])
|
||||||
|
|
||||||
@ -404,9 +398,7 @@ def streaming_forward(
|
|||||||
new_processed_lens = processed_lens + x_lens
|
new_processed_lens = processed_lens + x_lens
|
||||||
|
|
||||||
# (batch, left_context_size + chunk_size)
|
# (batch, left_context_size + chunk_size)
|
||||||
src_key_padding_mask = torch.cat(
|
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||||
[processed_mask, src_key_padding_mask], dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
encoder_states = states[:-2]
|
encoder_states = states[:-2]
|
||||||
@ -494,9 +486,7 @@ def decode_one_chunk(
|
|||||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
greedy_search(
|
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||||
model=model, encoder_out=encoder_out, streams=decode_streams
|
|
||||||
)
|
|
||||||
elif params.decoding_method == "fast_beam_search":
|
elif params.decoding_method == "fast_beam_search":
|
||||||
processed_lens = torch.tensor(processed_lens, device=device)
|
processed_lens = torch.tensor(processed_lens, device=device)
|
||||||
processed_lens = processed_lens + encoder_out_lens
|
processed_lens = processed_lens + encoder_out_lens
|
||||||
@ -517,9 +507,7 @@ def decode_one_chunk(
|
|||||||
num_active_paths=params.num_active_paths,
|
num_active_paths=params.num_active_paths,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
|
||||||
)
|
|
||||||
|
|
||||||
states = unstack_states(new_states)
|
states = unstack_states(new_states)
|
||||||
|
|
||||||
@ -577,9 +565,7 @@ def decode_dataset(
|
|||||||
decode_streams = []
|
decode_streams = []
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
initial_states = get_init_states(
|
initial_states = get_init_states(model=model, batch_size=1, device=device)
|
||||||
model=model, batch_size=1, device=device
|
|
||||||
)
|
|
||||||
decode_stream = DecodeStream(
|
decode_stream = DecodeStream(
|
||||||
params=params,
|
params=params,
|
||||||
cut_id=cut.id,
|
cut_id=cut.id,
|
||||||
@ -649,9 +635,7 @@ def decode_dataset(
|
|||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
key = f"num_active_paths_{params.num_active_paths}"
|
key = f"num_active_paths_{params.num_active_paths}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
|
||||||
)
|
|
||||||
return {key: decode_results}
|
return {key: decode_results}
|
||||||
|
|
||||||
|
|
||||||
@ -684,8 +668,7 @@ def save_results(
|
|||||||
|
|
||||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
errs_info = (
|
errs_info = (
|
||||||
params.res_dir
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
|
||||||
)
|
)
|
||||||
with open(errs_info, "w") as f:
|
with open(errs_info, "w") as f:
|
||||||
print("settings\tWER", file=f)
|
print("settings\tWER", file=f)
|
||||||
@ -703,7 +686,7 @@ def save_results(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -718,9 +701,7 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
assert params.causal, params.causal
|
assert params.causal, params.causal
|
||||||
assert (
|
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
|
||||||
"," not in params.chunk_size
|
|
||||||
), "chunk_size should be one value in decoding."
|
|
||||||
assert (
|
assert (
|
||||||
"," not in params.left_context_frames
|
"," not in params.left_context_frames
|
||||||
), "left_context_frames should be one value in decoding."
|
), "left_context_frames should be one value in decoding."
|
||||||
@ -760,9 +741,9 @@ def main():
|
|||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg
|
||||||
)[: params.avg]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -789,9 +770,9 @@ def main():
|
|||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
params.exp_dir, iteration=-params.iter
|
: params.avg + 1
|
||||||
)[: params.avg + 1]
|
]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -846,7 +827,7 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = AsrDataModule(args)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
|
@ -30,7 +30,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir zipformer/exp \
|
--exp-dir zipformer/exp \
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 1000
|
--max-duration 1000
|
||||||
|
|
||||||
# For streaming model training:
|
# For streaming model training:
|
||||||
@ -41,7 +40,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir zipformer/exp \
|
--exp-dir zipformer/exp \
|
||||||
--causal 1 \
|
--causal 1 \
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 1000
|
--max-duration 1000
|
||||||
|
|
||||||
It supports training with:
|
It supports training with:
|
||||||
@ -65,7 +63,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
@ -1173,14 +1171,10 @@ def run(rank, world_size, args):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
data_module = AsrDataModule(args)
|
||||||
multi_dataset = MultiDataset(args.manifest_dir)
|
multi_dataset = MultiDataset(args.manifest_dir)
|
||||||
|
|
||||||
train_cuts = multi_dataset.train_cuts()
|
train_cuts = multi_dataset.train_cuts()
|
||||||
# train_cuts = librispeech.train_clean_100_cuts()
|
|
||||||
# if params.full_libri:
|
|
||||||
# train_cuts += librispeech.train_clean_360_cuts()
|
|
||||||
# train_cuts += librispeech.train_other_500_cuts()
|
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
@ -1228,23 +1222,21 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(
|
train_dl = data_module.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
# valid_cuts = librispeech.dev_clean_cuts()
|
|
||||||
# valid_cuts += librispeech.dev_other_cuts()
|
|
||||||
valid_cuts = multi_dataset.dev_cuts()
|
valid_cuts = multi_dataset.dev_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = data_module.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
# if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
# scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
# model=model,
|
model=model,
|
||||||
# train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
# optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
# sp=sp,
|
sp=sp,
|
||||||
# params=params,
|
params=params,
|
||||||
# )
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
@ -1374,7 +1366,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
AsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user