Finish training code.

This commit is contained in:
Fangjun Kuang 2022-02-16 14:24:34 +08:00
parent e978948a26
commit 018d03cd08
3 changed files with 215 additions and 31 deletions

View File

@ -133,6 +133,15 @@ class AsrDataModule:
help="Path to directory with train/valid/test cuts.", help="Path to directory with train/valid/test cuts.",
) )
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available. Used only in dev/test CutSet",
)
def train_dataloaders( def train_dataloaders(
self, self,
cuts_train: CutSet, cuts_train: CutSet,
@ -240,3 +249,56 @@ class AsrDataModule:
persistent_workers=False, persistent_workers=False,
) )
return train_dl return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = BucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = BucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl

View File

@ -34,6 +34,8 @@ class Transducer(nn.Module):
encoder: EncoderInterface, encoder: EncoderInterface,
decoder: nn.Module, decoder: nn.Module,
joiner: nn.Module, joiner: nn.Module,
decoder_giga: nn.Module,
joiner_giga: nn.Module,
): ):
""" """
Args: Args:
@ -50,20 +52,30 @@ class Transducer(nn.Module):
It has two inputs with shapes: (N, T, C) and (N, U, C). Its It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax. unnormalized probs, i.e., not processed by log-softmax.
decoder_giga:
The decoder for the GigaSpeech dataset.
joiner_giga:
The joiner for the GigaSpeech dataset.
""" """
super().__init__() super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder) assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id") assert hasattr(decoder, "blank_id")
assert hasattr(decoder_giga, "blank_id")
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
self.joiner = joiner self.joiner = joiner
self.decoder_giga = decoder_giga
self.joiner_giga = joiner_giga
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
y: k2.RaggedTensor, y: k2.RaggedTensor,
libri: bool = True,
modified_transducer_prob: float = 0.0, modified_transducer_prob: float = 0.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -76,6 +88,9 @@ class Transducer(nn.Module):
y: y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance. utterance.
libri:
True to use the decoder and joiner for the LibriSpeech dataset.
False to use the decoder and joiner for the GigaSpeech dataset.
modified_transducer_prob: modified_transducer_prob:
The probability to use modified transducer loss. The probability to use modified transducer loss.
Returns: Returns:
@ -100,10 +115,17 @@ class Transducer(nn.Module):
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64) sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded) if libri:
decoder = self.decoder
joiner = self.joiner
else:
decoder = self.decoder_giga
joiner = self.joiner_giga
decoder_out = decoder(sos_y_padded)
# +1 here since a blank is prepended to each utterance. # +1 here since a blank is prepended to each utterance.
logits = self.joiner( logits = joiner(
encoder_out=encoder_out, encoder_out=encoder_out,
decoder_out=decoder_out, decoder_out=decoder_out,
encoder_out_len=x_lens, encoder_out_len=x_lens,

View File

@ -21,11 +21,11 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_stateless/train.py \ ./transducer_stateless_multi_datasets/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer_stateless/exp \ --exp-dir transducer_stateless_multi_datasets/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 250 \ --max-duration 250 \
--lr-factor 2.5 --lr-factor 2.5
@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import random
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -43,12 +44,15 @@ import sentencepiece as spm
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import AsrDataModule
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from gigaspeech import GigaSpeech
from joiner import Joiner from joiner import Joiner
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from librispeech import LibriSpeech
from model import Transducer from model import Transducer
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -82,6 +86,14 @@ def get_parser():
help="Master port to use for DDP training.", help="Master port to use for DDP training.",
) )
parser.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
parser.add_argument( parser.add_argument(
"--tensorboard", "--tensorboard",
type=str2bool, type=str2bool,
@ -109,7 +121,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="transducer_stateless/exp", default="transducer_stateless_multi_datasets/exp",
help="""The experiment dir. help="""The experiment dir.
It specifies the directory where all training related It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved files, e.g., checkpoints, log, etc, are saved
@ -259,13 +271,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module: def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
model = Transducer( model = Transducer(
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
decoder_giga=decoder_giga,
joiner_giga=joiner_giga,
) )
return model return model
@ -357,6 +375,17 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename) copyfile(src=filename, dst=best_valid_filename)
def is_libri(c: Cut) -> bool:
"""Return True if this cut is from the LibriSpeech dataset.
Note:
During data preparation, we set the custom field in
the supervision segment of GigaSpeech to dict(origin='giga')
See ../local/preprocess_gigaspeech.py.
"""
return c.supervisions[0].custom is None
def compute_loss( def compute_loss(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
@ -389,6 +418,8 @@ def compute_loss(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
libri = is_libri(supervisions["cut"][0])
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
@ -398,6 +429,7 @@ def compute_loss(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
libri=libri,
modified_transducer_prob=params.modified_transducer_prob, modified_transducer_prob=params.modified_transducer_prob,
) )
@ -452,7 +484,9 @@ def train_one_epoch(
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
giga_train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
) -> None: ) -> None:
@ -473,6 +507,8 @@ def train_one_epoch(
Dataloader for the training dataset. Dataloader for the training dataset.
valid_dl: valid_dl:
Dataloader for the validation dataset. Dataloader for the validation dataset.
rng:
For select which dataset to use.
tb_writer: tb_writer:
Writer to write log messages to tensorboard. Writer to write log messages to tensorboard.
world_size: world_size:
@ -482,7 +518,27 @@ def train_one_epoch(
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl): # index 0: for LibriSpeech
# index 1: for GigaSpeech
# This sets the probabilities for choosing which datasets
dl_weights = [0.8, 0.2]
iter_libri = iter(train_dl)
iter_giga = iter(giga_train_dl)
batch_idx = 0
while True:
idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
dl = iter_libri if idx == 0 else iter_giga
try:
batch = next(dl)
except StopIteration:
break
batch_idx += 1
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -544,6 +600,25 @@ def train_one_epoch(
params.best_train_loss = params.train_loss params.best_train_loss = params.train_loss
def filter_short_and_long_utterances(cuts: CutSet) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(cuts)
cuts = cuts.filter(remove_short_and_long_utt)
num_left = len(cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}")
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
return cuts
def run(rank, world_size, args): def run(rank, world_size, args):
""" """
Args: Args:
@ -562,7 +637,9 @@ def run(rank, world_size, args):
params.valid_interval = 800 params.valid_interval = 800
params.warm_step = 8000 params.warm_step = 8000
fix_random_seed(42) seed = 42
fix_random_seed(seed)
rng = random.Random(seed)
if world_size > 1: if world_size > 1:
setup_dist(rank, world_size, params.master_port) setup_dist(rank, world_size, params.master_port)
@ -599,7 +676,7 @@ def run(rank, world_size, args):
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model.device = device model.device = device
optimizer = Noam( optimizer = Noam(
@ -613,45 +690,66 @@ def run(rank, world_size, args):
logging.info("Loading optimizer state dict") logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts() train_cuts += librispeech.train_other_500_cuts()
def remove_short_and_long_utt(c: Cut): train_cuts = filter_short_and_long_utterances(train_cuts)
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours
# L 2.5k hours
# M 1k hours
# S 250 hours
# XS 10 hours
# DEV 12 hours
# Test 40 hours
# train_giga_cuts = gigaspeech.train_M_cuts()
train_giga_cuts = gigaspeech.train_S_cuts()
train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt) if args.enable_musan:
cuts_musan = load_manifest(
Path(args.manifest_dir) / "cuts_musan.json.gz"
)
else:
cuts_musan = None
num_left = len(train_cuts) asr_datamodule = AsrDataModule(args)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
logging.info(f"Before removing short and long utterances: {num_in_total}") train_dl = asr_datamodule.train_dataloaders(
logging.info(f"After removing short and long utterances: {num_left}") train_cuts,
logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") dynamic_bucketing=False,
on_the_fly_feats=False,
cuts_musan=cuts_musan,
)
train_dl = librispeech.train_dataloaders(train_cuts) giga_train_dl = asr_datamodule.train_dataloaders(
train_giga_cuts,
dynamic_bucketing=True,
on_the_fly_feats=True,
cuts_musan=cuts_musan,
)
valid_cuts = librispeech.dev_clean_cuts() valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts() valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts) valid_dl = asr_datamodule.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom( for dl in [train_dl, giga_train_dl]:
model=model, scan_pessimistic_batches_for_oom(
train_dl=train_dl, model=model,
optimizer=optimizer, train_dl=dl,
sp=sp, optimizer=optimizer,
params=params, sp=sp,
) params=params,
)
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
giga_train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
@ -671,7 +769,9 @@ def run(rank, world_size, args):
optimizer=optimizer, optimizer=optimizer,
sp=sp, sp=sp,
train_dl=train_dl, train_dl=train_dl,
giga_train_dl=giga_train_dl,
valid_dl=valid_dl, valid_dl=valid_dl,
rng=rng,
tb_writer=tb_writer, tb_writer=tb_writer,
world_size=world_size, world_size=world_size,
) )
@ -731,7 +831,7 @@ def scan_pessimistic_batches_for_oom(
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) AsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)