This commit is contained in:
Yifan Yang 2023-07-13 15:42:43 +08:00
parent 8936365c5c
commit 03abdb3712
9 changed files with 113 additions and 41 deletions

View File

@ -29,7 +29,7 @@ vocab_sizes=(
multidataset=( multidataset=(
"gigaspeech", "gigaspeech",
"commonvoice", "commonvoice",
"peoples_speech", "librilight",
) )
# All files generated by this script are saved in "data". # All files generated by this script are saved in "data".
@ -164,18 +164,18 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
cd ../.. cd ../..
fi fi
# People's Speech # LibriLight
if [[ "${multidataset[@]}" =~ "peoples_speech" ]] && [ ! -f data/fbank/.peoples_speech.done ]; then if [[ "${multidataset[@]}" =~ "librilight" ]] && [ ! -f data/fbank/.librilight.done ]; then
log "Dataset: People's Speech" log "Dataset: LibriLight"
cd data/fbank cd data/fbank
if [ -f ../../../../peoples_speech/ASR/data/fbank/.peoples_speech_train.done ]; then if [ -f ../../../../librilight/ASR/data/fbank/.librilight_train.done ]; then
ln -svf $(realpath ../../../../peoples_speech/ASR/data/fbank/peoples_speech_train_split) . ln -svf $(realpath ../../../../librilight/ASR/data/fbank/librilight_train_split) .
else else
log "Abort! Please run ../../peoples_speech/ASR/prepare.sh --stage 5 --stop-stage 6" log "Abort! Please run ../../librilight/ASR/prepare.sh --stage 5 --stop-stage 6"
exit 1 exit 1
fi fi
touch .peoples_speech.done touch .librilight.done
cd ../.. cd ../..
fi fi
fi fi

View File

@ -88,7 +88,7 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,

View File

@ -116,7 +116,8 @@ from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from train import add_model_arguments, get_params, get_model from multidataset import MultiDataset
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -782,6 +783,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
multidataset = MultiDataset(args.manifest_dir)
test_clean_cuts = librispeech.test_clean_cuts() test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts() test_other_cuts = librispeech.test_other_cuts()
@ -789,8 +791,30 @@ def main():
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts) test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"] test_cuts = multidataset.test_cuts()
test_dl = [test_clean_dl, test_other_dl]
gigaspeech_dev_dl = librispeech.test_dataloaders(test_cuts[0])
gigaspeech_test_dl = librispeech.test_dataloaders(test_cuts[1])
commonvoice_dev_dl = librispeech.test_dataloaders(test_cuts[2])
commonvoice_test_dl = librispeech.test_dataloaders(test_cuts[3])
test_sets = [
"librispeech-test-clean",
"librispeech-test-other",
"gigaspeech-dev",
"gigaspeech-test",
"commonvoice-dev",
"commonvoice-test",
]
test_dl = [
test_clean_dl,
test_other_dl,
gigaspeech_dev_dl,
gigaspeech_test_dl,
commonvoice_dev_dl,
commonvoice_test_dl,
]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(

View File

@ -76,7 +76,7 @@ import torch.nn as nn
from decoder import Decoder from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2 from zipformer import Zipformer2
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -85,7 +85,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool, make_pad_mask from icefall.utils import make_pad_mask, str2bool
def get_parser(): def get_parser():
@ -182,7 +182,10 @@ class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner""" """A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__( def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear self,
encoder: Zipformer2,
encoder_embed: nn.Module,
encoder_proj: nn.Linear,
): ):
""" """
Args: Args:
@ -210,7 +213,11 @@ class OnnxEncoder(nn.Module):
left_context_len = self.left_context_len left_context_len = self.left_context_len
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( (
x,
x_lens,
new_cached_embed_left_pad,
) = self.encoder_embed.streaming_forward(
x=x, x=x,
x_lens=x_lens, x_lens=x_lens,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,

View File

@ -74,7 +74,7 @@ import torch.nn as nn
from decoder import Decoder from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2 from zipformer import Zipformer2
from icefall.checkpoint import ( from icefall.checkpoint import (
@ -83,7 +83,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool, make_pad_mask from icefall.utils import make_pad_mask, str2bool
def get_parser(): def get_parser():
@ -180,7 +180,10 @@ class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner""" """A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__( def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear self,
encoder: Zipformer2,
encoder_embed: nn.Module,
encoder_proj: nn.Linear,
): ):
""" """
Args: Args:

View File

@ -160,8 +160,9 @@ from typing import List, Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn from torch import Tensor, nn
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -170,7 +171,6 @@ from icefall.checkpoint import (
load_checkpoint, load_checkpoint,
) )
from icefall.utils import make_pad_mask, str2bool from icefall.utils import make_pad_mask, str2bool
from scaling_converter import convert_scaled_to_non_scaled
def get_parser(): def get_parser():
@ -315,7 +315,11 @@ class StreamingEncoderModel(nn.Module):
left_context_len = self.left_context_len left_context_len = self.left_context_len
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( (
x,
x_lens,
new_cached_embed_left_pad,
) = self.encoder_embed.streaming_forward(
x=features, x=features,
x_lens=feature_lengths, x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,

View File

@ -71,30 +71,57 @@ class MultiDataset:
self.manifest_dir / f"cv-en_cuts_train.jsonl.gz" self.manifest_dir / f"cv-en_cuts_train.jsonl.gz"
) )
# People's Speech # LibriHeavy
sorted_filenames = sorted( logging.info("Loading LibriHeavy in lazy mode")
glob.glob( libriheavy_small_cuts = load_manifest_lazy(
f"{self.manifest_dir}/peoples_speech_train_split/peoples_speech_cuts_*[yna].*.jsonl.gz" self.manifest_dir / "libriheavy_cuts_train_small.jsonl.gz"
)
) )
libriheavy_medium_cuts = load_manifest_lazy(
logging.info( self.manifest_dir / "libriheavy_cuts_train_medium.jsonl.gz"
f"Loading People's Speech {len(sorted_filenames)} splits in lazy mode"
)
peoples_speech_cuts = lhotse.combine(
lhotse.load_manifest_lazy(p) for p in sorted_filenames
) )
libriheavy_cuts = lhotse.combine(libriheavy_small_cuts, libriheavy_medium_cuts)
return CutSet.mux( return CutSet.mux(
librispeech_cuts, librispeech_cuts,
gigaspeech_cuts, gigaspeech_cuts,
commonvoice_cuts, commonvoice_cuts,
peoples_speech_cuts, libriheavy_cuts,
weights=[ weights=[
len(librispeech_cuts), len(librispeech_cuts),
len(gigaspeech_cuts), len(gigaspeech_cuts),
len(commonvoice_cuts), len(commonvoice_cuts),
len(peoples_speech_cuts), len(libriheavy_cuts),
], ],
) )
def test_cuts(self) -> CutSet:
logging.info("About to get multidataset test cuts")
# GigaSpeech
logging.info("Loading GigaSpeech DEV in lazy mode")
gigaspeech_dev_cuts = load_manifest_lazy(
self.manifest_dir / "cuts_DEV.jsonl.gz"
)
logging.info("Loading GigaSpeech TEST in lazy mode")
gigaspeech_test_cuts = load_manifest_lazy(
self.manifest_dir / "cuts_TEST.jsonl.gz"
)
# CommonVoice
logging.info("Loading CommonVoice DEV in lazy mode")
commonvoice_dev_cuts = load_manifest_lazy(
self.manifest_dir / "cv-en_cuts_dev.jsonl.gz"
)
logging.info("Loading CommonVoice TEST in lazy mode")
commonvoice_test_cuts = load_manifest_lazy(
self.manifest_dir / "cv-en_cuts_test.jsonl.gz"
)
return [
gigaspeech_dev_cuts,
gigaspeech_test_cuts,
commonvoice_dev_cuts,
commonvoice_test_cuts,
]

View File

@ -51,7 +51,7 @@ from streaming_beam_search import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -374,7 +374,11 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states. Returns encoder outputs, output lengths, and updated states.
""" """
cached_embed_left_pad = states[-2] cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad) = model.encoder_embed.streaming_forward( (
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features, x=features,
x_lens=feature_lens, x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad, cached_left_pad=cached_embed_left_pad,

View File

@ -66,13 +66,13 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from multidataset import MultiDataset
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
from multidataset import MultiDataset
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
@ -344,7 +344,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lr-hours", "--lr-hours",
type=float, type=float,
default=5000, default=70000,
help="""Number of hours that affects how rapidly the learning rate decreases. help="""Number of hours that affects how rapidly the learning rate decreases.
""", """,
) )
@ -1052,7 +1052,9 @@ def train_one_epoch(
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16: if params.use_fp16:
tb_writer.add_scalar( tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train "train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
) )
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
@ -1387,5 +1389,6 @@ def main():
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()