Use GigaSpeech dataset as extra training data.

Fangjun Kuang 2022-03-10 10:13:49 +08:00
parent 9071b1420d
commit 35f5a15a54
7 changed files with 233 additions and 39 deletions

View File

@@ -1 +1 @@
-../transducer/asr_datamodule.py
+../transducer_stateless_multi_datasets/asr_datamodule.py

View File

@@ -0,0 +1 @@
+../transducer_stateless_multi_datasets/gigaspeech.py

View File

@@ -0,0 +1 @@
+../transducer_stateless_multi_datasets/librispeech.py

View File

@@ -15,6 +15,8 @@
# limitations under the License.
+from typing import Optional
+
import k2
import torch
import torch.nn as nn
@@ -33,6 +35,8 @@ class Transducer(nn.Module):
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
+        decoder_giga: Optional[nn.Module] = None,
+        joiner_giga: Optional[nn.Module] = None,
    ):
        """
        Args:
@@ -49,20 +53,32 @@ class Transducer(nn.Module):
            It has two inputs with shapes: (N, T, U, C) and (N, T, U, C). Its
            output shape is also (N, T, U, C). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
+          decoder_giga:
+            The decoder for the GigaSpeech dataset.
+          joiner_giga:
+            The joiner for the GigaSpeech dataset.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")
+        if decoder_giga is not None:
+            assert hasattr(decoder_giga, "blank_id")

        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner
+        self.decoder_giga = decoder_giga
+        self.joiner_giga = joiner_giga

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
+        libri: bool = True,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
@@ -77,6 +93,9 @@ class Transducer(nn.Module):
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
+          libri:
+            True to use the decoder and joiner for the LibriSpeech dataset.
+            False to use the decoder and joiner for the GigaSpeech dataset.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
@@ -114,8 +133,15 @@ class Transducer(nn.Module):
        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

+        if libri:
+            decoder = self.decoder
+            joiner = self.joiner
+        else:
+            decoder = self.decoder_giga
+            joiner = self.joiner_giga
+
        # decoder_out: [B, S + 1, C]
-        decoder_out = self.decoder(sos_y_padded)
+        decoder_out = decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
@@ -155,7 +181,7 @@ class Transducer(nn.Module):
        )

        # logits : [B, T, prune_range, C]
-        logits = self.joiner(am_pruned, lm_pruned)
+        logits = joiner(am_pruned, lm_pruned)

        pruned_loss = k2.rnnt_loss_pruned(
            logits=logits,

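The change above keeps a single shared encoder and adds a second decoder/joiner pair that forward() selects through the new `libri` flag. Below is a minimal, self-contained sketch of that routing pattern; the toy nn.Linear modules stand in for the real Conformer encoder, Decoder, and Joiner, so this is an illustration of the idea rather than the icefall model itself.

import torch
import torch.nn as nn


class TinyTransducer(nn.Module):
    """Toy stand-in showing the shared-encoder, per-dataset-head layout."""

    def __init__(self, encoder, decoder, joiner, decoder_giga=None, joiner_giga=None):
        super().__init__()
        self.encoder = encoder            # shared by both corpora
        self.decoder = decoder            # LibriSpeech head
        self.joiner = joiner
        self.decoder_giga = decoder_giga  # GigaSpeech head (optional)
        self.joiner_giga = joiner_giga

    def forward(self, x: torch.Tensor, libri: bool = True) -> torch.Tensor:
        enc_out = self.encoder(x)
        # The only branch: pick the head that matches the batch's corpus.
        if libri:
            decoder, joiner = self.decoder, self.joiner
        else:
            decoder, joiner = self.decoder_giga, self.joiner_giga
        return joiner(enc_out + decoder(enc_out))  # placeholder computation


model = TinyTransducer(
    encoder=nn.Linear(8, 8),
    decoder=nn.Linear(8, 8),
    joiner=nn.Linear(8, 8),
    decoder_giga=nn.Linear(8, 8),
    joiner_giga=nn.Linear(8, 8),
)
x = torch.randn(2, 8)
print(model(x, libri=True).shape)   # routed through the LibriSpeech head
print(model(x, libri=False).shape)  # routed through the GigaSpeech head

Only the decoder/joiner pair differs per corpus; the shared encoder receives gradients from both LibriSpeech and GigaSpeech batches.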
View File

@@ -19,20 +19,44 @@
"""
Usage:

+cd egs/librispeech/ASR/
+
+./prepare.sh
+./prepare_giga_speech.sh
+
+# 100-hours
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./pruned_transducer_stateless_multi_datasets/train.py \
+  --world-size 2 \
+  --num-epochs 60 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless_multi_datasets/exp-1 \
+  --full-libri 0 \
+  --max-duration 300 \
+  --prune-range 5 \
+  --lr-factor 1.0 \
+  --lm-scale 0.25
+
+# 960 hours
export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless/train.py \
+./pruned_transducer_stateless_multi_datasets/train.py \
  --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 60 \
  --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless/exp \
+  --exp-dir pruned_transducer_stateless_multi_datasets/exp-full \
  --full-libri 1 \
-  --max-duration 300
+  --max-duration 300 \
+  --prune-range 5 \
+  --lr-factor 5.0 \
+  --lm-scale 0.25
"""

import argparse
import logging
+import random
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
@@ -42,12 +66,15 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import AsrDataModule
from conformer import Conformer
from decoder import Decoder
+from gigaspeech import GigaSpeech
from joiner import Joiner
+from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
+from librispeech import LibriSpeech
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -89,6 +116,14 @@ def get_parser():
        help="Master port to use for DDP training.",
    )

+    parser.add_argument(
+        "--full-libri",
+        type=str2bool,
+        default=True,
+        help="When enabled, use 960h LibriSpeech. "
+        "Otherwise, use 100h subset.",
+    )
+
    parser.add_argument(
        "--tensorboard",
        type=str2bool,
@@ -116,7 +151,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless_multi_datasets/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -179,6 +214,13 @@ def get_parser():
        "with this parameter before adding to the final loss.",
    )

+    parser.add_argument(
+        "--giga-prob",
+        type=float,
+        default=0.2,
+        help="The probability to select a batch from the GigaSpeech dataset",
+    )
+
    parser.add_argument(
        "--seed",
        type=int,
@@ -253,8 +295,6 @@ def get_params() -> AttributeDict:
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
            # parameters for Noam
            "warm_step": 80000,  # For the 100h subset, use 30000
            "env_info": get_env_info(),
@@ -302,13 +342,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

+    decoder_giga = get_decoder_model(params)
+    joiner_giga = get_joiner_model(params)
+
    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
+        decoder_giga=decoder_giga,
+        joiner_giga=joiner_giga,
    )
    return model
@@ -400,6 +446,17 @@ def save_checkpoint(
        copyfile(src=filename, dst=best_valid_filename)


+def is_libri(c: Cut) -> bool:
+    """Return True if this cut is from the LibriSpeech dataset.
+
+    Note:
+      During data preparation, we set the custom field in
+      the supervision segment of GigaSpeech to dict(origin='giga')
+      See ../local/preprocess_gigaspeech.py.
+    """
+    return c.supervisions[0].custom is None
+
+
def compute_loss(
    params: AttributeDict,
    model: nn.Module,
@@ -432,6 +489,8 @@ def compute_loss(
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

+    libri = is_libri(supervisions["cut"][0])
+
    texts = batch["supervisions"]["text"]
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y).to(device)
@@ -441,6 +500,7 @@
            x=feature,
            x_lens=feature_lens,
            y=y,
+            libri=libri,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
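The is_libri() helper added above keys off the supervision's custom field: ../local/preprocess_gigaspeech.py attaches dict(origin='giga') to every GigaSpeech supervision, while LibriSpeech supervisions leave custom as None. A small sketch with hand-built lhotse objects (IDs and durations are made up; the constructors are assumed to match lhotse's MonoCut/SupervisionSegment signatures):

from lhotse import SupervisionSegment
from lhotse.cut import MonoCut


def is_libri(c) -> bool:
    # Same check as in train.py: GigaSpeech cuts carry custom=dict(origin="giga").
    return c.supervisions[0].custom is None


libri_cut = MonoCut(
    id="libri-utt", start=0.0, duration=2.0, channel=0,
    supervisions=[
        SupervisionSegment(
            id="libri-sup", recording_id="libri-rec", start=0.0, duration=2.0
        )
    ],
)
giga_cut = MonoCut(
    id="giga-utt", start=0.0, duration=2.0, channel=0,
    supervisions=[
        SupervisionSegment(
            id="giga-sup", recording_id="giga-rec", start=0.0, duration=2.0,
            custom={"origin": "giga"},
        )
    ],
)

assert is_libri(libri_cut) is True
assert is_libri(giga_cut) is False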
@@ -500,7 +560,9 @@ def train_one_epoch(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    train_dl: torch.utils.data.DataLoader,
+    giga_train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
+    rng: random.Random,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
) -> None:
@@ -519,8 +581,12 @@
        The optimizer we are using.
      train_dl:
        Dataloader for the training dataset.
+      giga_train_dl:
+        Dataloader for the GigaSpeech training dataset.
      valid_dl:
        Dataloader for the validation dataset.
+      rng:
+        For selecting which dataset to use.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
@@ -528,6 +594,8 @@
    """
    model.train()

+    libri_tot_loss = MetricsTracker()
+    giga_tot_loss = MetricsTracker()
    tot_loss = MetricsTracker()

    def maybe_log_gradients(tag: str):
@@ -569,10 +637,32 @@
        else:
            optimizer.step()

-    for batch_idx, batch in enumerate(train_dl):
+    # index 0: for LibriSpeech
+    # index 1: for GigaSpeech
+    # This sets the probabilities for choosing which datasets
+    dl_weights = [1 - params.giga_prob, params.giga_prob]
+
+    iter_libri = iter(train_dl)
+    iter_giga = iter(giga_train_dl)
+
+    batch_idx = 0
+
+    while True:
+        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
+        dl = iter_libri if idx == 0 else iter_giga
+
+        try:
+            batch = next(dl)
+        except StopIteration:
+            break
+
+        batch_idx += 1
+
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

+        libri = is_libri(batch["supervisions"]["cut"][0])
+
        loss, loss_info = compute_loss(
            params=params,
            model=model,
@@ -582,6 +672,16 @@
        )

        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+        if libri:
+            libri_tot_loss = (
+                libri_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "libri"  # for logging only
+        else:
+            giga_tot_loss = (
+                giga_tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info
+            prefix = "giga"

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
@@ -597,18 +697,29 @@
        if batch_idx % params.log_interval == 0:
            logging.info(
                f"Epoch {params.cur_epoch}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], batch size: {batch_size}"
+                f"batch {batch_idx}, {prefix}_loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], "
+                f"libri_tot_loss[{libri_tot_loss}], "
+                f"giga_tot_loss[{giga_tot_loss}], "
+                f"batch size: {batch_size}"
            )

        if batch_idx % params.log_interval == 0:
            if tb_writer is not None:
                loss_info.write_summary(
-                    tb_writer, "train/current_", params.batch_idx_train
+                    tb_writer,
+                    f"train/current_{prefix}_",
+                    params.batch_idx_train,
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
+                libri_tot_loss.write_summary(
+                    tb_writer, "train/libri_tot_", params.batch_idx_train
+                )
+                giga_tot_loss.write_summary(
+                    tb_writer, "train/giga_tot_", params.batch_idx_train
+                )

        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
            logging.info("Computing validation loss")
@@ -633,6 +744,25 @@
        params.best_train_loss = params.train_loss


+def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        return 1.0 <= c.duration <= 20.0
+
+    num_in_total = len(cuts)
+    cuts = cuts.filter(remove_short_and_long_utt)
+
+    num_left = len(cuts)
+    num_removed = num_in_total - num_left
+    removed_percent = num_removed / num_in_total * 100
+
+    logging.info(f"Before removing short and long utterances: {num_in_total}")
+    logging.info(f"After removing short and long utterances: {num_left}")
+    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
+
+    return cuts
+
+
def run(rank, world_size, args):
    """
    Args:
@@ -652,6 +782,7 @@ def run(rank, world_size, args):
        params.warm_step = 30000

    fix_random_seed(params.seed)
+    rng = random.Random(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)
@@ -688,7 +819,7 @@ def run(rank, world_size, args):
    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    model.device = device

    optimizer = Noam(
@@ -702,46 +833,74 @@ def run(rank, world_size, args):
        logging.info("Loading optimizer state dict")
        optimizer.load_state_dict(checkpoints["optimizer"])

-    librispeech = LibriSpeechAsrDataModule(args)
+    librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:
        train_cuts += librispeech.train_clean_360_cuts()
        train_cuts += librispeech.train_other_500_cuts()

-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        return 1.0 <= c.duration <= 20.0
-
-    num_in_total = len(train_cuts)
-
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-
-    num_left = len(train_cuts)
-    num_removed = num_in_total - num_left
-    removed_percent = num_removed / num_in_total * 100
-
-    logging.info(f"Before removing short and long utterances: {num_in_total}")
-    logging.info(f"After removing short and long utterances: {num_left}")
-    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
-
-    train_dl = librispeech.train_dataloaders(train_cuts)
+    train_cuts = filter_short_and_long_utterances(train_cuts)
+
+    gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
+    # XL 10k hours
+    # L 2.5k hours
+    # M 1k hours
+    # S 250 hours
+    # XS 10 hours
+    # DEV 12 hours
+    # Test 40 hours
+    if params.full_libri:
+        logging.info("Using the L subset of GigaSpeech (2.5k hours)")
+        train_giga_cuts = gigaspeech.train_L_cuts()
+    else:
+        logging.info("Using the S subset of GigaSpeech (250 hours)")
+        train_giga_cuts = gigaspeech.train_S_cuts()
+
+    train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
+
+    if args.enable_musan:
+        cuts_musan = load_manifest(
+            Path(args.manifest_dir) / "cuts_musan.json.gz"
+        )
+    else:
+        cuts_musan = None
+
+    asr_datamodule = AsrDataModule(args)
+
+    train_dl = asr_datamodule.train_dataloaders(
+        train_cuts,
+        dynamic_bucketing=False,
+        on_the_fly_feats=False,
+        cuts_musan=cuts_musan,
+    )
+
+    giga_train_dl = asr_datamodule.train_dataloaders(
+        train_giga_cuts,
+        dynamic_bucketing=True,
+        on_the_fly_feats=True,
+        cuts_musan=cuts_musan,
+    )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)

-    scan_pessimistic_batches_for_oom(
-        model=model,
-        train_dl=train_dl,
-        optimizer=optimizer,
-        sp=sp,
-        params=params,
-    )
+    # It's time consuming to include `giga_train_dl` here
+    # for dl in [train_dl, giga_train_dl]:
+    for dl in [train_dl]:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

    for epoch in range(params.start_epoch, params.num_epochs):
        fix_random_seed(params.seed + epoch)
        train_dl.sampler.set_epoch(epoch)
+        giga_train_dl.sampler.set_epoch(epoch)

        cur_lr = optimizer._rate
        if tb_writer is not None:
@@ -761,7 +920,9 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            sp=sp,
            train_dl=train_dl,
+            giga_train_dl=giga_train_dl,
            valid_dl=valid_dl,
+            rng=rng,
            tb_writer=tb_writer,
            world_size=world_size,
        )
@@ -821,10 +982,12 @@ def scan_pessimistic_batches_for_oom(
def main():
    parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

+    assert 0 <= args.giga_prob < 1, args.giga_prob
+
    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:

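The core of the new training loop is the weighted interleaving of the two dataloaders: at every step rng.choices picks LibriSpeech with probability 1 - giga_prob and GigaSpeech with probability giga_prob, and the epoch ends as soon as either iterator is exhausted. A standalone sketch of that pattern with toy iterators in place of the real dataloaders (the names below are illustrative only):

import random

giga_prob = 0.2
rng = random.Random(42)

# Toy stand-ins for train_dl and giga_train_dl.
libri_batches = iter([f"libri-{i}" for i in range(8)])
giga_batches = iter([f"giga-{i}" for i in range(8)])

# index 0: LibriSpeech, index 1: GigaSpeech
dl_weights = [1 - giga_prob, giga_prob]

batch_idx = 0
while True:
    idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
    dl = libri_batches if idx == 0 else giga_batches
    try:
        batch = next(dl)
    except StopIteration:
        # As in train.py, the epoch stops when either source runs dry,
        # so leftover batches from the other source are not seen this epoch.
        break
    batch_idx += 1
    print(batch_idx, batch)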
View File

@@ -535,10 +535,12 @@ def train_one_epoch(
        The optimizer we are using.
      train_dl:
        Dataloader for the training dataset.
+      giga_train_dl:
+        Dataloader for the GigaSpeech training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      rng:
-        For select which dataset to use.
+        For selecting which dataset to use.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:

View File

@@ -97,6 +97,7 @@ def get_env_info() -> Dict[str, Any]:
        "lhotse-version": lhotse.__version__,
        "torch-cuda-available": torch.cuda.is_available(),
        "torch-cuda-version": torch.version.cuda,
+        "torch-version": torch.__version__,
        "python-version": sys.version[:3],
        "icefall-git-branch": get_git_branch_name(),
        "icefall-git-sha1": get_git_sha1(),