Use LibriSpeech + GigaSpeech with modified conformer.

Fangjun Kuang 2022-04-12 17:28:01 +08:00
parent 337309267b
commit bbf074a36b
5 changed files with 680 additions and 59 deletions

pruned_transducer_stateless3/asr_datamodule.py

@@ -0,0 +1,304 @@
# Copyright 2021 Piotr Żelasko
#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from pathlib import Path
from typing import Optional

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (
    BucketingSampler,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (
    OnTheFlyFeatures,
    PrecomputedFeatures,
)
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class AsrDataModule:
    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )

        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )

        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the BucketingSampler "
            "and DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )

        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with the training dataset.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available. Used only in dev/test CutSet.",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        dynamic_bucketing: bool,
        on_the_fly_feats: bool,
        cuts_musan: Optional[CutSet] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            Cuts for training.
          cuts_musan:
            If not None, it is the cuts for mixing.
          dynamic_bucketing:
            True to use DynamicBucketingSampler;
            False to use BucketingSampler.
          on_the_fly_feats:
            True to use OnTheFlyFeatures;
            False to use PrecomputedFeatures.
        """
        transforms = []
        if cuts_musan is not None:
            logging.info("Enable MUSAN")
            transforms.append(
                CutMix(
                    cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
                )
            )
        else:
            logging.info("Disable MUSAN")

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(
                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
            )
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=2,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        # NOTE: the PerturbSpeed transform should be added only if we
        # remove it from data prep stage.
        # Add on-the-fly speed perturbation; since originally it would
        # have increased epoch size by 3, we will apply prob 2/3 and use
        # 3x more epochs.
        # Speed perturbation probably should come first before
        # concatenation, but in principle the transforms order doesn't have
        # to be strict (e.g. could be randomized)
        # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
        # Drop feats to be on the safe side.
        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_strategy=(
                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                if on_the_fly_feats
                else PrecomputedFeatures()
            ),
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if dynamic_bucketing:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=True,
            )
        else:
            logging.info("Using BucketingSampler.")
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                bucket_method="equal_duration",
                drop_last=True,
            )

        logging.info("About to create train dataloader")
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = BucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(
                Fbank(FbankConfig(num_mel_bins=80))
            )
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            return_cuts=self.args.return_cuts,
        )
        sampler = BucketingSampler(
            cuts, max_duration=self.args.max_duration, shuffle=False
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl
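
A minimal usage sketch (not part of the commit) showing how this module is meant to be driven; it assumes the precomputed LibriSpeech fbank manifest from ./prepare.sh exists under data/fbank:

import argparse

from lhotse import load_manifest

parser = argparse.ArgumentParser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args([])  # take the defaults, e.g. --manifest-dir data/fbank

asr_datamodule = AsrDataModule(args)
train_cuts = load_manifest(args.manifest_dir / "cuts_train-clean-100.json.gz")
train_dl = asr_datamodule.train_dataloaders(
    train_cuts,
    dynamic_bucketing=False,  # fixed-bucket sampler over precomputed features
    on_the_fly_feats=False,
    cuts_musan=None,  # pass MUSAN cuts here to enable CutMix noise mixing
)

for batch in train_dl:
    print(batch["inputs"].shape)  # (N, T, 80) log-mel fbank features
    break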

pruned_transducer_stateless3/gigaspeech.py

@@ -0,0 +1,75 @@
# Copyright 2021 Piotr Żelasko
#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path

from lhotse import CutSet, load_manifest


class GigaSpeech:
    def __init__(self, manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files::

                - cuts_XL_raw.jsonl.gz
                - cuts_L_raw.jsonl.gz
                - cuts_M_raw.jsonl.gz
                - cuts_S_raw.jsonl.gz
                - cuts_XS_raw.jsonl.gz
                - cuts_DEV_raw.jsonl.gz
                - cuts_TEST_raw.jsonl.gz
        """
        self.manifest_dir = Path(manifest_dir)

    def train_XL_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_XL_raw.jsonl.gz"
        logging.info(f"About to get train-XL cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_L_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_L_raw.jsonl.gz"
        logging.info(f"About to get train-L cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_M_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_M_raw.jsonl.gz"
        logging.info(f"About to get train-M cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_S_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_S_raw.jsonl.gz"
        logging.info(f"About to get train-S cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def train_XS_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_XS_raw.jsonl.gz"
        logging.info(f"About to get train-XS cuts from {f}")
        return CutSet.from_jsonl_lazy(f)

    def test_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_TEST.jsonl.gz"
        logging.info(f"About to get TEST cuts from {f}")
        return load_manifest(f)

    def dev_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_DEV.jsonl.gz"
        logging.info(f"About to get DEV cuts from {f}")
        return load_manifest(f)
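
The training subsets are returned lazily via CutSet.from_jsonl_lazy, so even the 10k-hour XL manifest can be scanned without loading all cut metadata into memory. A short usage sketch (assumes the manifests produced by ./prepare_giga_speech.sh are present):

gigaspeech = GigaSpeech(manifest_dir="data/fbank")

xl_cuts = gigaspeech.train_XL_cuts()  # lazy: cuts are parsed on demand
cut = next(iter(xl_cuts))
print(cut.duration, cut.supervisions[0].text)

dev_cuts = gigaspeech.dev_cuts()  # eager: small enough to load fully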

pruned_transducer_stateless3/librispeech.py

@@ -0,0 +1,74 @@
# Copyright 2021 Piotr Żelasko
#           2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path

from lhotse import CutSet, load_manifest


class LibriSpeech:
    def __init__(self, manifest_dir: str):
        """
        Args:
          manifest_dir:
            It is expected to contain the following files::

                - cuts_dev-clean.json.gz
                - cuts_dev-other.json.gz
                - cuts_test-clean.json.gz
                - cuts_test-other.json.gz
                - cuts_train-clean-100.json.gz
                - cuts_train-clean-360.json.gz
                - cuts_train-other-500.json.gz
        """
        self.manifest_dir = Path(manifest_dir)

    def train_clean_100_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_train-clean-100.json.gz"
        logging.info(f"About to get train-clean-100 cuts from {f}")
        return load_manifest(f)

    def train_clean_360_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_train-clean-360.json.gz"
        logging.info(f"About to get train-clean-360 cuts from {f}")
        return load_manifest(f)

    def train_other_500_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_train-other-500.json.gz"
        logging.info(f"About to get train-other-500 cuts from {f}")
        return load_manifest(f)

    def test_clean_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_test-clean.json.gz"
        logging.info(f"About to get test-clean cuts from {f}")
        return load_manifest(f)

    def test_other_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_test-other.json.gz"
        logging.info(f"About to get test-other cuts from {f}")
        return load_manifest(f)

    def dev_clean_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_dev-clean.json.gz"
        logging.info(f"About to get dev-clean cuts from {f}")
        return load_manifest(f)

    def dev_other_cuts(self) -> CutSet:
        f = self.manifest_dir / "cuts_dev-other.json.gz"
        logging.info(f"About to get dev-other cuts from {f}")
        return load_manifest(f)
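
train.py below assembles the full 960 h training set by simply concatenating the three subsets, since CutSet supports +. A sketch mirroring run() (assumes the manifests exist):

librispeech = LibriSpeech(manifest_dir="data/fbank")

train_cuts = librispeech.train_clean_100_cuts()
full_libri = True  # corresponds to the --full-libri option in train.py
if full_libri:
    train_cuts += librispeech.train_clean_360_cuts()
    train_cuts += librispeech.train_other_500_cuts()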

pruned_transducer_stateless3/model.py

@@ -15,6 +15,8 @@
 # limitations under the License.

+from typing import Optional
+
 import k2
 import torch
 import torch.nn as nn
@@ -38,6 +40,8 @@ class Transducer(nn.Module):
         decoder_dim: int,
         joiner_dim: int,
         vocab_size: int,
+        decoder_giga: Optional[nn.Module] = None,
+        joiner_giga: Optional[nn.Module] = None,
     ):
         """
         Args:
@@ -51,11 +55,25 @@ class Transducer(nn.Module):
             is (N, U) and its output shape is (N, U, decoder_dim).
             It should contain one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-            Its output shape is (N, T, U, vocab_size). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
+            Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
+          encoder_dim:
+            Output dimension of the encoder network.
+          decoder_dim:
+            Output dimension of the decoder network.
+          joiner_dim:
+            Input dimension of the joiner network.
+          vocab_size:
+            Output dimension of the joiner network.
+          decoder_giga:
+            Optional. The decoder network for the GigaSpeech dataset.
+          joiner_giga:
+            Optional. The joiner network for the GigaSpeech dataset.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
         assert hasattr(decoder, "blank_id")
@@ -63,16 +81,26 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

+        self.decoder_giga = decoder_giga
+        self.joiner_giga = joiner_giga
+
         self.simple_am_proj = ScaledLinear(
             encoder_dim, vocab_size, initial_speed=0.5
         )
         self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

+        if decoder_giga is not None:
+            self.simple_am_proj_giga = ScaledLinear(
+                encoder_dim, vocab_size, initial_speed=0.5
+            )
+            self.simple_lm_proj_giga = ScaledLinear(decoder_dim, vocab_size)
+
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        libri: bool = True,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
@@ -88,6 +116,9 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
+          libri:
+            True to use the decoder and joiner for the LibriSpeech dataset.
+            False to use the decoder and joiner for the GigaSpeech dataset.
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -115,21 +146,32 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
-        assert torch.all(x_lens > 0)
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(encoder_out_lens > 0)
+
+        if libri:
+            decoder = self.decoder
+            simple_lm_proj = self.simple_lm_proj
+            simple_am_proj = self.simple_am_proj
+            joiner = self.joiner
+        else:
+            decoder = self.decoder_giga
+            simple_lm_proj = self.simple_lm_proj_giga
+            simple_am_proj = self.simple_am_proj_giga
+            joiner = self.joiner_giga

         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]

-        blank_id = self.decoder.blank_id
+        blank_id = decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

         # decoder_out: [B, S + 1, decoder_dim]
-        decoder_out = self.decoder(sos_y_padded)
+        decoder_out = decoder(sos_y_padded)

         # Note: y does not start with SOS
         # y_padded : [B, S]
@@ -140,10 +182,10 @@ class Transducer(nn.Module):
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
         boundary[:, 2] = y_lens
-        boundary[:, 3] = x_lens
+        boundary[:, 3] = encoder_out_lens

-        lm = self.simple_lm_proj(decoder_out)
-        am = self.simple_am_proj(encoder_out)
+        lm = simple_lm_proj(decoder_out)
+        am = simple_am_proj(encoder_out)

         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
@@ -169,8 +211,8 @@ class Transducer(nn.Module):
         # am_pruned : [B, T, prune_range, encoder_dim]
         # lm_pruned : [B, T, prune_range, decoder_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=self.joiner.encoder_proj(encoder_out),
-            lm=self.joiner.decoder_proj(decoder_out),
+            am=joiner.encoder_proj(encoder_out),
+            lm=joiner.decoder_proj(decoder_out),
             ranges=ranges,
         )
@@ -178,7 +220,7 @@ class Transducer(nn.Module):
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = joiner(am_pruned, lm_pruned, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
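
In short, the model now has one shared encoder and two dataset-specific heads (decoder, joiner, and the simple projections); forward(..., libri=...) routes each batch to exactly one head. A hypothetical construction sketch, not taken from this commit: the builder calls and the 512/500 sizes are placeholders, and the two losses are unpacked as the forward() above computes them.

# decoder_giga/joiner_giga mirror the architecture of the LibriSpeech head
model = Transducer(
    encoder=encoder,            # shared between LibriSpeech and GigaSpeech
    decoder=decoder,
    joiner=joiner,
    decoder_giga=decoder_giga,  # used only when forward(..., libri=False)
    joiner_giga=joiner_giga,
    encoder_dim=512,
    decoder_dim=512,
    joiner_dim=512,
    vocab_size=500,
)

# Each batch trains one head plus the shared encoder:
simple_loss, pruned_loss = model(
    x=feats, x_lens=feat_lens, y=labels, libri=True
)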

pruned_transducer_stateless3/train.py

@@ -21,22 +21,26 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless2/train.py \
+cd egs/librispeech/ASR/
+./prepare.sh
+./prepare_giga_speech.sh
+
+./pruned_transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless2/train.py \
+./pruned_transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
   --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 550

@@ -45,6 +49,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -56,13 +61,16 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AsrDataModule
 from conformer import Conformer
 from decoder import Decoder
+from gigaspeech import GigaSpeech
 from joiner import Joiner
+from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from librispeech import LibriSpeech
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -109,6 +117,14 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )

+    parser.add_argument(
+        "--full-libri",
+        type=str2bool,
+        default=True,
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
+    )
+
     parser.add_argument(
         "--num-epochs",
         type=int,
@@ -122,7 +138,7 @@ def get_parser():
         default=0,
         help="""Resume training from from this epoch.
         If it is positive, it will load checkpoint from
-        transducer_stateless2/exp/epoch-{start_epoch-1}.pt
+        transducer_stateless3/exp/epoch-{start_epoch-1}.pt
         """,
     )
@@ -138,7 +154,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless3/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -156,7 +172,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )

     parser.add_argument(
@@ -170,7 +187,7 @@ def get_parser():
     parser.add_argument(
         "--lr-epochs",
         type=float,
-        default=6,
+        default=4,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
         """,
     )
@@ -262,6 +279,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--giga-prob",
+        type=float,
+        default=0.5,
+        help="The probability to select a batch from the GigaSpeech dataset",
+    )
+
     return parser
@@ -377,10 +401,15 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

+    decoder_giga = get_decoder_model(params)
+    joiner_giga = get_joiner_model(params)
+
     model = Transducer(
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        decoder_giga=decoder_giga,
+        joiner_giga=joiner_giga,
         encoder_dim=params.encoder_dim,
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
@@ -448,9 +477,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]

-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params
@@ -500,6 +526,17 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def is_libri(c: Cut) -> bool:
+    """Return True if this cut is from the LibriSpeech dataset.
+
+    Note:
+      During data preparation, we set the custom field in
+      the supervision segment of GigaSpeech to dict(origin='giga').
+      See ../local/preprocess_gigaspeech.py.
+    """
+    return c.supervisions[0].custom is None
+
+
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
@@ -535,6 +572,8 @@ def compute_loss(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

+    libri = is_libri(supervisions["cut"][0])
+
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
@@ -544,6 +583,7 @@
             x=feature,
             x_lens=feature_lens,
             y=y,
+            libri=libri,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
@@ -621,7 +661,9 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
+    giga_train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    rng: random.Random,
     scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
@@ -644,8 +686,12 @@
         The learning rate scheduler, we call step() every step.
       train_dl:
         Dataloader for the training dataset.
+      giga_train_dl:
+        Dataloader for the GigaSpeech training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      rng:
+        For selecting which dataset to use.
       scaler:
         The scaler used for mix precision training.
       tb_writer:
@@ -658,18 +704,36 @@
     """
     model.train()

+    libri_tot_loss = MetricsTracker()
+    giga_tot_loss = MetricsTracker()
     tot_loss = MetricsTracker()

-    cur_batch_idx = params.get("cur_batch_idx", 0)
+    # index 0: for LibriSpeech
+    # index 1: for GigaSpeech
+    # This sets the probabilities for choosing which datasets
+    dl_weights = [1 - params.giga_prob, params.giga_prob]

-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
+    iter_libri = iter(train_dl)
+    iter_giga = iter(giga_train_dl)
+
+    batch_idx = 0
+
+    while True:
+        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+        dl = iter_libri if idx == 0 else iter_giga
+
+        try:
+            batch = next(dl)
+        except StopIteration:
+            break
+
+        batch_idx += 1

         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

+        libri = is_libri(batch["supervisions"]["cut"][0])
+
         with torch.cuda.amp.autocast(enabled=params.use_fp16):
             loss, loss_info = compute_loss(
                 params=params,
@@ -682,6 +746,17 @@
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

+        if libri:
+            libri_tot_loss = (
+                libri_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "libri"  # for logging only
+        else:
+            giga_tot_loss = (
+                giga_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "giga"
+
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
@@ -697,7 +772,6 @@
                 params.batch_idx_train > 0
                 and params.batch_idx_train % params.save_every_n == 0
             ):
-                params.cur_batch_idx = batch_idx
                 save_checkpoint_with_global_batch_idx(
                     out_dir=params.exp_dir,
                     global_batch_idx=params.batch_idx_train,
@@ -709,7 +783,6 @@
                     scaler=scaler,
                     rank=rank,
                 )
-                del params.cur_batch_idx
                 remove_checkpoints(
                     out_dir=params.exp_dir,
                     topk=params.keep_last_k,
@@ -720,8 +793,11 @@
             cur_lr = scheduler.get_last_lr()[0]
             logging.info(
                 f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+                f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], "
+                f"libri_tot_loss[{libri_tot_loss}], "
+                f"giga_tot_loss[{giga_tot_loss}], "
+                f"batch size: {batch_size}"
                 f"lr: {cur_lr:.2e}"
             )
@@ -731,11 +807,19 @@
                 )
                 loss_info.write_summary(
-                    tb_writer, "train/current_", params.batch_idx_train
+                    tb_writer,
+                    f"train/current_{prefix}_",
+                    params.batch_idx_train,
                 )
                 tot_loss.write_summary(
                     tb_writer, "train/tot_", params.batch_idx_train
                 )
+                libri_tot_loss.write_summary(
+                    tb_writer, "train/libri_tot_", params.batch_idx_train
+                )
+                giga_tot_loss.write_summary(
+                    tb_writer, "train/giga_tot_", params.batch_idx_train
+                )

         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -760,6 +844,23 @@
     params.best_train_loss = params.train_loss


+def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        return 1.0 <= c.duration <= 20.0
+
+    cuts = cuts.filter(remove_short_and_long_utt)
+
+    return cuts
+
+
 def run(rank, world_size, args):
     """
     Args:
@@ -778,6 +879,7 @@ def run(rank, world_size, args):
     params.valid_interval = 1600

     fix_random_seed(params.seed)
+    rng = random.Random(params.seed)
     if world_size > 1:
         setup_dist(rank, world_size, params.master_port)
@@ -814,7 +916,7 @@
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
     model.device = device

     optimizer = Eve(model.parameters(), lr=params.initial_lr)
@@ -839,45 +941,65 @@
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

     train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
         train_cuts += librispeech.train_clean_360_cuts()
         train_cuts += librispeech.train_other_500_cuts()

-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        return 1.0 <= c.duration <= 20.0
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    train_cuts = filter_short_and_long_utterances(train_cuts)

-    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
-        # We only load the sampler's state dict when it loads a checkpoint
-        # saved in the middle of an epoch
-        sampler_state_dict = checkpoints["sampler"]
-    else:
-        sampler_state_dict = None
-
-    train_dl = librispeech.train_dataloaders(
-        train_cuts, sampler_state_dict=sampler_state_dict
-    )
+    gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
+    # XL 10k hours
+    # L  2.5k hours
+    # M  1k hours
+    # S  250 hours
+    # XS 10 hours
+    # DEV 12 hours
+    # Test 40 hours
+    if params.full_libri:
+        logging.info("Using the XL subset of GigaSpeech (10k hours)")
+        train_giga_cuts = gigaspeech.train_XL_cuts()
+    else:
+        logging.info("Using the S subset of GigaSpeech (250 hours)")
+        train_giga_cuts = gigaspeech.train_S_cuts()
+
+    train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
+
+    if args.enable_musan:
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "cuts_musan.json.gz"
+        )
+    else:
+        cuts_musan = None
+
+    asr_datamodule = AsrDataModule(args)
+
+    train_dl = asr_datamodule.train_dataloaders(
+        train_cuts,
+        dynamic_bucketing=False,
+        on_the_fly_feats=False,
+        cuts_musan=cuts_musan,
+    )
+
+    giga_train_dl = asr_datamodule.train_dataloaders(
+        train_giga_cuts,
+        dynamic_bucketing=True,
+        on_the_fly_feats=True,
+        cuts_musan=cuts_musan,
+    )

     valid_cuts = librispeech.dev_clean_cuts()
     valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
+        # It's time consuming to include `giga_train_dl` here
+        # for dl in [train_dl, giga_train_dl]:
+        for dl in [train_dl]:
+            scan_pessimistic_batches_for_oom(
+                model=model,
+                train_dl=dl,
+                optimizer=optimizer,
+                sp=sp,
+                params=params,
@@ -905,7 +1027,9 @@
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
+            giga_train_dl=giga_train_dl,
             valid_dl=valid_dl,
+            rng=rng,
             scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
@@ -978,10 +1102,12 @@ def scan_pessimistic_batches_for_oom(
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

+    assert 0 <= args.giga_prob < 1, args.giga_prob
+
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
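
The net effect of the train_one_epoch changes: training multiplexes the two dataloaders, drawing each batch from LibriSpeech with probability 1 - giga_prob and from GigaSpeech with probability giga_prob, and ending the epoch when either stream is exhausted. A standalone sketch of just that mechanism (not from the commit; plain iterables stand in for the dataloaders):

import random

def mux(libri_batches, giga_batches, giga_prob=0.5, seed=42):
    # Mirrors the while True / StopIteration loop in train_one_epoch.
    rng = random.Random(seed)
    iterators = (iter(libri_batches), iter(giga_batches))
    weights = [1 - giga_prob, giga_prob]
    while True:
        idx = rng.choices((0, 1), weights=weights, k=1)[0]
        try:
            yield ("libri", "giga")[idx], next(iterators[idx])
        except StopIteration:
            return  # one stream ran dry; the epoch ends

# With giga_prob=0.5 roughly half of the drawn batches come from each
# stream until the shorter one (here the 5-element one) runs out.
for origin, batch in mux(range(5), range(100)):
    print(origin, batch)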