mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
update
This commit is contained in:
parent
8936365c5c
commit
03abdb3712
@ -29,7 +29,7 @@ vocab_sizes=(
|
|||||||
multidataset=(
|
multidataset=(
|
||||||
"gigaspeech",
|
"gigaspeech",
|
||||||
"commonvoice",
|
"commonvoice",
|
||||||
"peoples_speech",
|
"librilight",
|
||||||
)
|
)
|
||||||
|
|
||||||
# All files generated by this script are saved in "data".
|
# All files generated by this script are saved in "data".
|
||||||
@ -164,18 +164,18 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
cd ../..
|
cd ../..
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# People's Speech
|
# LibriLight
|
||||||
if [[ "${multidataset[@]}" =~ "peoples_speech" ]] && [ ! -f data/fbank/.peoples_speech.done ]; then
|
if [[ "${multidataset[@]}" =~ "librilight" ]] && [ ! -f data/fbank/.librilight.done ]; then
|
||||||
log "Dataset: People's Speech"
|
log "Dataset: LibriLight"
|
||||||
cd data/fbank
|
cd data/fbank
|
||||||
if [ -f ../../../../peoples_speech/ASR/data/fbank/.peoples_speech_train.done ]; then
|
if [ -f ../../../../librilight/ASR/data/fbank/.librilight_train.done ]; then
|
||||||
ln -svf $(realpath ../../../../peoples_speech/ASR/data/fbank/peoples_speech_train_split) .
|
ln -svf $(realpath ../../../../librilight/ASR/data/fbank/librilight_train_split) .
|
||||||
else
|
else
|
||||||
log "Abort! Please run ../../peoples_speech/ASR/prepare.sh --stage 5 --stop-stage 6"
|
log "Abort! Please run ../../librilight/ASR/prepare.sh --stage 5 --stop-stage 6"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
touch .peoples_speech.done
|
touch .librilight.done
|
||||||
cd ../..
|
cd ../..
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -116,7 +116,8 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_model
|
from multidataset import MultiDataset
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -782,6 +783,7 @@ def main():
|
|||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
multidataset = MultiDataset(args.manifest_dir)
|
||||||
|
|
||||||
test_clean_cuts = librispeech.test_clean_cuts()
|
test_clean_cuts = librispeech.test_clean_cuts()
|
||||||
test_other_cuts = librispeech.test_other_cuts()
|
test_other_cuts = librispeech.test_other_cuts()
|
||||||
@ -789,8 +791,30 @@ def main():
|
|||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
test_cuts = multidataset.test_cuts()
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
|
||||||
|
gigaspeech_dev_dl = librispeech.test_dataloaders(test_cuts[0])
|
||||||
|
gigaspeech_test_dl = librispeech.test_dataloaders(test_cuts[1])
|
||||||
|
commonvoice_dev_dl = librispeech.test_dataloaders(test_cuts[2])
|
||||||
|
commonvoice_test_dl = librispeech.test_dataloaders(test_cuts[3])
|
||||||
|
|
||||||
|
test_sets = [
|
||||||
|
"librispeech-test-clean",
|
||||||
|
"librispeech-test-other",
|
||||||
|
"gigaspeech-dev",
|
||||||
|
"gigaspeech-test",
|
||||||
|
"commonvoice-dev",
|
||||||
|
"commonvoice-test",
|
||||||
|
]
|
||||||
|
|
||||||
|
test_dl = [
|
||||||
|
test_clean_dl,
|
||||||
|
test_other_dl,
|
||||||
|
gigaspeech_dev_dl,
|
||||||
|
gigaspeech_test_dl,
|
||||||
|
commonvoice_dev_dl,
|
||||||
|
commonvoice_test_dl,
|
||||||
|
]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
|
@ -76,7 +76,7 @@ import torch.nn as nn
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
from zipformer import Zipformer2
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -85,7 +85,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool, make_pad_mask
|
from icefall.utils import make_pad_mask, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -182,7 +182,10 @@ class OnnxEncoder(nn.Module):
|
|||||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
self,
|
||||||
|
encoder: Zipformer2,
|
||||||
|
encoder_embed: nn.Module,
|
||||||
|
encoder_proj: nn.Linear,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -210,7 +213,11 @@ class OnnxEncoder(nn.Module):
|
|||||||
left_context_len = self.left_context_len
|
left_context_len = self.left_context_len
|
||||||
|
|
||||||
cached_embed_left_pad = states[-2]
|
cached_embed_left_pad = states[-2]
|
||||||
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
(
|
||||||
|
x,
|
||||||
|
x_lens,
|
||||||
|
new_cached_embed_left_pad,
|
||||||
|
) = self.encoder_embed.streaming_forward(
|
||||||
x=x,
|
x=x,
|
||||||
x_lens=x_lens,
|
x_lens=x_lens,
|
||||||
cached_left_pad=cached_embed_left_pad,
|
cached_left_pad=cached_embed_left_pad,
|
||||||
|
@ -74,7 +74,7 @@ import torch.nn as nn
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
from zipformer import Zipformer2
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -83,7 +83,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool, make_pad_mask
|
from icefall.utils import make_pad_mask, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -180,7 +180,10 @@ class OnnxEncoder(nn.Module):
|
|||||||
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
self,
|
||||||
|
encoder: Zipformer2,
|
||||||
|
encoder_embed: nn.Module,
|
||||||
|
encoder_proj: nn.Linear,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -160,8 +160,9 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -170,7 +171,6 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, str2bool
|
from icefall.utils import make_pad_mask, str2bool
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -315,7 +315,11 @@ class StreamingEncoderModel(nn.Module):
|
|||||||
left_context_len = self.left_context_len
|
left_context_len = self.left_context_len
|
||||||
|
|
||||||
cached_embed_left_pad = states[-2]
|
cached_embed_left_pad = states[-2]
|
||||||
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
(
|
||||||
|
x,
|
||||||
|
x_lens,
|
||||||
|
new_cached_embed_left_pad,
|
||||||
|
) = self.encoder_embed.streaming_forward(
|
||||||
x=features,
|
x=features,
|
||||||
x_lens=feature_lengths,
|
x_lens=feature_lengths,
|
||||||
cached_left_pad=cached_embed_left_pad,
|
cached_left_pad=cached_embed_left_pad,
|
||||||
|
@ -71,30 +71,57 @@ class MultiDataset:
|
|||||||
self.manifest_dir / f"cv-en_cuts_train.jsonl.gz"
|
self.manifest_dir / f"cv-en_cuts_train.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
# People's Speech
|
# LibriHeavy
|
||||||
sorted_filenames = sorted(
|
logging.info("Loading LibriHeavy in lazy mode")
|
||||||
glob.glob(
|
libriheavy_small_cuts = load_manifest_lazy(
|
||||||
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*[yna].*.jsonl.gz"
|
self.manifest_dir / "libriheavy_cuts_train_small.jsonl.gz"
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
libriheavy_medium_cuts = load_manifest_lazy(
|
||||||
logging.info(
|
self.manifest_dir / "libriheavy_cuts_train_medium.jsonl.gz"
|
||||||
f"Loading People's Speech {len(sorted_filenames)} splits in lazy mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
peoples_speech_cuts = lhotse.combine(
|
|
||||||
lhotse.load_manifest_lazy(p) for p in sorted_filenames
|
|
||||||
)
|
)
|
||||||
|
libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts)
|
||||||
|
|
||||||
return CutSet.mux(
|
return CutSet.mux(
|
||||||
librispeech_cuts,
|
librispeech_cuts,
|
||||||
gigaspeech_cuts,
|
gigaspeech_cuts,
|
||||||
commonvoice_cuts,
|
commonvoice_cuts,
|
||||||
peoples_speech_cuts,
|
libriheavy_cuts,
|
||||||
weights=[
|
weights=[
|
||||||
len(librispeech_cuts),
|
len(librispeech_cuts),
|
||||||
len(gigaspeech_cuts),
|
len(gigaspeech_cuts),
|
||||||
len(commonvoice_cuts),
|
len(commonvoice_cuts),
|
||||||
len(peoples_speech_cuts),
|
len(libriheavy_cuts),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get multidataset test cuts")
|
||||||
|
|
||||||
|
# GigaSpeech
|
||||||
|
logging.info("Loading GigaSpeech DEV in lazy mode")
|
||||||
|
gigaspeech_dev_cuts = load_manifest_lazy(
|
||||||
|
self.manifest_dir / "cuts_DEV.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Loading GigaSpeech TEST in lazy mode")
|
||||||
|
gigaspeech_test_cuts = load_manifest_lazy(
|
||||||
|
self.manifest_dir / "cuts_TEST.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
# CommonVoice
|
||||||
|
logging.info("Loading CommonVoice DEV in lazy mode")
|
||||||
|
commonvoice_dev_cuts = load_manifest_lazy(
|
||||||
|
self.manifest_dir / "cv-en_cuts_dev.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Loading CommonVoice TEST in lazy mode")
|
||||||
|
commonvoice_test_cuts = load_manifest_lazy(
|
||||||
|
self.manifest_dir / "cv-en_cuts_test.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
gigaspeech_dev_cuts,
|
||||||
|
gigaspeech_test_cuts,
|
||||||
|
commonvoice_dev_cuts,
|
||||||
|
commonvoice_test_cuts,
|
||||||
|
]
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -374,7 +374,11 @@ def streaming_forward(
|
|||||||
Returns encoder outputs, output lengths, and updated states.
|
Returns encoder outputs, output lengths, and updated states.
|
||||||
"""
|
"""
|
||||||
cached_embed_left_pad = states[-2]
|
cached_embed_left_pad = states[-2]
|
||||||
(x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward(
|
(
|
||||||
|
x,
|
||||||
|
x_lens,
|
||||||
|
new_cached_embed_left_pad,
|
||||||
|
) = model.encoder_embed.streaming_forward(
|
||||||
x=features,
|
x=features,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
cached_left_pad=cached_embed_left_pad,
|
cached_left_pad=cached_embed_left_pad,
|
||||||
|
@ -66,13 +66,13 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from multidataset import MultiDataset
|
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import AsrModel
|
from model import AsrModel
|
||||||
|
from multidataset import MultiDataset
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from scaling import ScheduledFloat
|
from scaling import ScheduledFloat
|
||||||
from subsampling import Conv2dSubsampling
|
from subsampling import Conv2dSubsampling
|
||||||
@ -344,7 +344,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-hours",
|
"--lr-hours",
|
||||||
type=float,
|
type=float,
|
||||||
default=5000,
|
default=70000,
|
||||||
help="""Number of hours that affects how rapidly the learning rate decreases.
|
help="""Number of hours that affects how rapidly the learning rate decreases.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -1052,7 +1052,9 @@ def train_one_epoch(
|
|||||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||||
if params.use_fp16:
|
if params.use_fp16:
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale",
|
||||||
|
cur_grad_scale,
|
||||||
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||||
@ -1387,5 +1389,6 @@ def main():
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user