Add training code.

Author: Fangjun Kuang
Date: 2021-12-13 13:50:53 +08:00
parent 232caf51ee
commit cd5ed7db20
2 changed files with 66 additions and 143 deletions

transducer/model.py

@@ -121,7 +121,7 @@ class Transducer(nn.Module):
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="mean",
reduction="sum",
)
return loss
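
The keyword arguments visible in this hunk (logit_lengths, target_lengths, blank, reduction) match torchaudio.functional.rnnt_loss, so the standalone sketch below assumes that API; every size in it is made up. Switching reduction from "mean" to "sum" makes the returned value the total loss over the batch, and train.py (second file below) divides by the batch size before logging so the displayed number stays a per-utterance average.

# Standalone sketch, not part of the commit. Assumes torchaudio.functional.rnnt_loss;
# all sizes below are made up for illustration.
import torch
import torchaudio.functional as AF

N, T, U, V, blank_id = 2, 50, 10, 500, 0           # batch, frames, target length, vocab
logits = torch.randn(N, T, U + 1, V, requires_grad=True)
targets = torch.randint(1, V, (N, U), dtype=torch.int32)
x_lens = torch.full((N,), T, dtype=torch.int32)    # per-utterance frame counts
y_lens = torch.full((N,), U, dtype=torch.int32)    # per-utterance target lengths

loss = AF.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=x_lens,
    target_lengths=y_lens,
    blank=blank_id,
    reduction="sum",  # total over the batch instead of the per-utterance mean
)
print(loss.item() / N)  # per-utterance average, which is what gets logged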

transducer/train.py

@@ -25,6 +25,7 @@ from shutil import copyfile
from typing import Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@@ -36,21 +37,15 @@ from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transducer.conformer import Conformer
from transducer.decoder import Decoder
from transducer.joiner import Joiner
from transducer.model import Transducer
from transducer.transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
)
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
@@ -107,22 +102,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--bpe-model",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--att-rate",
type=float,
default=0.8,
help="""The attention rate.
The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
""",
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
@@ -178,16 +161,8 @@ def get_params() -> AttributeDict:
- attention_dim: Hidden dim for multi-head attention model.
- head: Number of heads of multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
@@ -213,16 +188,9 @@ def get_params() -> AttributeDict:
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"vocab_size": 500, # including blank
"decoder_embedding_dim": 1024,
"blank_id": 0,
"sos_id": 1,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000,
@@ -262,6 +230,27 @@ def get_decoder_model(params: AttributeDict):
return decoder
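
Only the return statement of get_decoder_model falls inside this hunk, so the decoder's constructor call is not shown. As a rough orientation, the sketch below is a hypothetical RNN-T prediction network sized by the new params above (vocab_size=500, decoder_embedding_dim=1024, decoder_hidden_dim=512, num_decoder_layers=4, blank_id=0); its output dimension is assumed to match the joiner input, and the actual transducer/decoder.py may differ.

# Hypothetical sketch, not the actual transducer/decoder.py.
import torch
import torch.nn as nn

class PredictionNetwork(nn.Module):
    """LSTM prediction network sized by the decoder params in get_params()."""

    def __init__(self, vocab_size=500, embedding_dim=1024, hidden_dim=512,
                 num_layers=4, blank_id=0, output_dim=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=blank_id)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, output_dim)  # output_dim is assumed

    def forward(self, y: torch.Tensor, states=None):
        # y: (N, U) token ids of the label history, starting with blank/sos.
        embedding_out = self.embedding(y)                  # (N, U, embedding_dim)
        rnn_out, states = self.rnn(embedding_out, states)  # (N, U, hidden_dim)
        return self.output_linear(rnn_out), states         # (N, U, output_dim)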
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
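
get_joiner_model passes only input_dim=params.encoder_out_dim and output_dim=params.vocab_size (encoder_out_dim is defined outside the hunks shown here). A typical RNN-T joiner with that interface adds the encoder and decoder outputs and projects to the vocabulary; the sketch below illustrates the idea and is not necessarily the actual transducer/joiner.py.

# Sketch of a joiner matching the Joiner(input_dim, output_dim) call above.
import torch
import torch.nn as nn

class SimpleJoiner(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
        # encoder_out: (N, T, C), decoder_out: (N, U + 1, C)
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U + 1, C)
        return self.output_linear(torch.tanh(logit))                 # (N, T, U + 1, output_dim)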
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@@ -352,8 +341,8 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
@@ -367,86 +356,35 @@ def compute_loss(
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
mmodel = model.module if hasattr(model, "module") else model
# Note: We need to generate an unsorted version of token_ids
# `encode_supervisions()` called above sorts text, but
# encoder_memory and memory_mask are not sorted, so we
# use an unsorted version `supervisions["text"]` to regenerate
# the token_ids
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
loss = model(x=feature, x_lens=feature_lens, y=y)
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["loss"] = loss.detach().cpu().item()
# We use reduction="sum" in computing the loss.
# The displayed loss is the average loss over the batch
info["loss"] = loss.detach().cpu().item() / feature.size(0)
return loss, info
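
Compared with the CTC path it replaces, compute_loss no longer needs a graph compiler: the transcripts are turned into a ragged tensor of BPE ids and handed to the model, and the frame count for the metrics is approximated from feature_lens and the subsampling factor. A standalone sketch with made-up transcripts and lengths:

# Standalone sketch, not part of the commit; transcripts and lengths are made up.
import k2
import sentencepiece as spm
import torch

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")     # path from --bpe-model

texts = ["HELLO WORLD", "FOO BAR"]
y = k2.RaggedTensor(sp.encode(texts, out_type=int))  # one list of BPE ids per utterance

feature_lens = torch.tensor([1000, 800])   # frames before subsampling
subsampling_factor = 4                     # params.subsampling_factor in the recipe
frames = (feature_lens // subsampling_factor).sum().item()  # 250 + 200 = 450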
@@ -454,7 +392,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: BpeCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@@ -467,8 +405,8 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
@@ -489,7 +427,7 @@ def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
@@ -508,8 +446,6 @@
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
@@ -530,8 +466,8 @@
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats
@@ -567,7 +503,7 @@
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
sp=sp,
valid_dl=valid_dl,
world_size=world_size,
)
@@ -606,50 +542,37 @@ def run(rank, world_size, args):
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
model = get_encoder_model(params)
model = get_decoder_model(params)
print(model)
return
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
@@ -659,7 +582,8 @@ def run(rank, world_size, args):
weight_decay=params.weight_decay,
)
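
The Noam optimizer's remaining constructor arguments fall outside this hunk. For reference, wrappers like this usually implement the standard Noam schedule: a linear learning-rate warm-up followed by inverse-square-root decay. In the sketch below, warm_step=80000 comes from get_params() above, while model_size and factor are placeholders.

# Sketch of the standard Noam schedule; not the actual Noam class from transducer/transformer.py.
def noam_lr(step: int, model_size: int = 512, factor: float = 5.0,
            warm_step: int = 80000) -> float:
    step = max(step, 1)
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

print(noam_lr(1), noam_lr(80_000), noam_lr(800_000))  # rises, peaks near warm_step, then decays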
if checkpoints:
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
@@ -678,7 +602,7 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
sp=sp,
params=params,
)
@@ -701,7 +625,7 @@ def run(rank, world_size, args):
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
@@ -726,7 +650,7 @@ def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@@ -742,8 +666,8 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
loss.backward()
@@ -766,7 +690,6 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1