Add training code.

Fangjun Kuang 2021-12-13 13:50:53 +08:00
parent 232caf51ee
commit cd5ed7db20
2 changed files with 66 additions and 143 deletions


@@ -121,7 +121,7 @@ class Transducer(nn.Module):
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
-            reduction="mean",
+            reduction="sum",
         )
         return loss
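The keyword arguments in this call (logit_lengths, target_lengths, blank, reduction) match the signature of torchaudio.functional.rnnt_loss, so the switch from reduction="mean" to reduction="sum" presumably makes the model return the transducer loss summed over the batch; the training script below then divides the displayed loss by the batch size. A minimal standalone sketch under that assumption, with made-up shapes:

# Sketch only: assumes Transducer.forward wraps torchaudio.functional.rnnt_loss
# (the keyword names above match its signature).
import torch
import torchaudio.functional as F

N, T, U, V = 2, 50, 11, 500          # batch, frames, target length + 1, vocab size
logits = torch.randn(N, T, U, V, requires_grad=True)
targets = torch.randint(1, V, (N, U - 1), dtype=torch.int32)
logit_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), U - 1, dtype=torch.int32)

loss = F.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=0,
    reduction="sum",                 # summed over the batch, not averaged
)
print(loss.item() / N)               # per-utterance value, as logged by the training script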


@@ -25,6 +25,7 @@ from shutil import copyfile
 from typing import Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -36,21 +37,15 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
 from transducer.transformer import Noam

-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -107,22 +102,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--bpe-model",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
-    parser.add_argument(
-        "--att-rate",
-        type=float,
-        default=0.8,
-        help="""The attention rate.
-        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
-        """,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
     )

     parser.add_argument(
@@ -178,16 +161,8 @@ def get_params() -> AttributeDict:

        - attention_dim: Hidden dim for multi-head attention model.

-        - head: Number of heads of multi-head attention model.
-
        - num_decoder_layers: Number of decoder layer of transformer decoder.

-        - beam_size: It is used in k2.ctc_loss
-
-        - reduction: It is used in k2.ctc_loss
-
-        - use_double_scores: It is used in k2.ctc_loss
-
        - weight_decay: The weight_decay for the optimizer.

        - warm_step: The warm_step for Noam optimizer.
@@ -213,16 +188,9 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
             # decoder params
-            "vocab_size": 500,  # including blank
             "decoder_embedding_dim": 1024,
-            "blank_id": 0,
-            "sos_id": 1,
             "num_decoder_layers": 4,
             "decoder_hidden_dim": 512,
-            # parameters for loss
-            "beam_size": 10,
-            "reduction": "sum",
-            "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,
@@ -262,6 +230,27 @@ def get_decoder_model(params: AttributeDict):
     return decoder


+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
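The diff only shows that the Joiner is constructed with input_dim=params.encoder_out_dim and output_dim=params.vocab_size; transducer/joiner.py itself is not part of this commit view. A hypothetical minimal joiner consistent with those arguments (the add-tanh-project form is an assumption, not taken from this commit):

# Hypothetical stand-in for transducer.joiner.Joiner; only the constructor
# arguments (input_dim, output_dim) are taken from the diff above.
import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # encoder_out: (N, T, 1, C), decoder_out: (N, 1, U, C).
        # Broadcasting the sum yields the (N, T, U, C) lattice consumed by
        # the transducer loss in the first file of this commit.
        logit = torch.tanh(encoder_out + decoder_out)
        return self.output_linear(logit)


joiner = Joiner(input_dim=512, output_dim=500)
out = joiner(torch.randn(2, 50, 1, 512), torch.randn(2, 1, 11, 512))
print(out.shape)  # torch.Size([2, 50, 11, 500])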
@@ -352,8 +341,8 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
+    sp: spm.SentencePieceProcessor,
     batch: dict,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
@@ -367,86 +356,35 @@ def compute_loss(
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
-      graph_compiler:
-        It is used to build a decoding graph from a ctc topo and training
-        transcript. The training transcript is contained in the given `batch`,
-        while the ctc topo is built when this compiler is instantiated.
       is_training:
         True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
     """
-    device = graph_compiler.device
+    device = model.device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is (N, T, C)
-
-        # NOTE: We need `encode_supervisions` to sort sequences with
-        # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in `k2.ctc_loss`
-        supervision_segments, texts = encode_supervisions(
-            supervisions, subsampling_factor=params.subsampling_factor
-        )
-
-        token_ids = graph_compiler.texts_to_ids(texts)
-        decoding_graph = graph_compiler.compile(token_ids)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
-
-    if params.att_rate != 0.0:
-        with torch.set_grad_enabled(is_training):
-            mmodel = model.module if hasattr(model, "module") else model
-            # Note: We need to generate an unsorted version of token_ids
-            # `encode_supervisions()` called above sorts text, but
-            # encoder_memory and memory_mask are not sorted, so we
-            # use an unsorted version `supervisions["text"]` to regenerate
-            # the token_ids
-            #
-            # See https://github.com/k2-fsa/icefall/issues/97
-            # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
-            att_loss = mmodel.decoder_forward(
-                encoder_memory,
-                memory_mask,
-                token_ids=unsorted_token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
-            )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-    else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+        loss = model(x=feature, x_lens=feature_lens, y=y)

     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = supervision_segments[:, 2].sum().item()
-    info["ctc_loss"] = ctc_loss.detach().cpu().item()
-    if params.att_rate != 0.0:
-        info["att_loss"] = att_loss.detach().cpu().item()
+    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

-    info["loss"] = loss.detach().cpu().item()
+    # We use reduction="sum" in computing the loss.
+    # The displayed loss is the average loss over the batch.
+    info["loss"] = loss.detach().cpu().item() / feature.size(0)

     return loss, info
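The rewritten compute_loss prepares targets by tokenizing the transcripts with the SentencePiece model and wrapping the ids in a k2.RaggedTensor, then passes features, feature lengths, and the ragged targets straight to the model. A standalone sketch of that target preparation; the bpe.model path is the new --bpe-model default, and the y_lens computation is a guess at how the model recovers per-utterance target lengths (it is not shown in this diff):

# Sketch of the target preparation done in compute_loss above.
import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")   # value of the new --bpe-model default

texts = ["HELLO WORLD", "GOOD MORNING"]
y = sp.encode(texts, out_type=int)       # list of lists of BPE token ids
y = k2.RaggedTensor(y)                   # ragged (N, U) integer tensor

# Per-utterance target lengths can be recovered from the ragged shape;
# presumably this is how y_lens in the model's loss call is obtained.
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
print(y)
print(y_lens)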
@@ -454,7 +392,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: nn.Module,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -467,8 +405,8 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=False,
         )
         assert loss.requires_grad is False
@@ -489,7 +427,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -508,8 +446,6 @@ def train_one_epoch(
         The model for training.
       optimizer:
         The optimizer we are using.
-      graph_compiler:
-        It is used to convert transcripts to FSAs.
       train_dl:
         Dataloader for the training dataset.
       valid_dl:
@@ -530,8 +466,8 @@ def train_one_epoch(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=True,
         )
         # summary stats
@@ -567,7 +503,7 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
+                sp=sp,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -606,50 +542,37 @@ def run(rank, world_size, args):
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)

     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None

-    model = get_encoder_model(params)
-    model = get_decoder_model(params)
-    print(model)
-    return
-
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")

-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)

     logging.info("About to create model")
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
+    model = get_transducer_model(params)

     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
     if world_size > 1:
+        logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
+    model.device = device

     optimizer = Noam(
         model.parameters(),
@@ -659,7 +582,8 @@ def run(rank, world_size, args):
         weight_decay=params.weight_decay,
     )

-    if checkpoints:
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])

     librispeech = LibriSpeechAsrDataModule(args)
@@ -678,7 +602,7 @@ def run(rank, world_size, args):
         model=model,
         train_dl=train_dl,
         optimizer=optimizer,
-        graph_compiler=graph_compiler,
+        sp=sp,
         params=params,
     )
@@ -701,7 +625,7 @@
             params=params,
             model=model,
             optimizer=optimizer,
-            graph_compiler=graph_compiler,
+            sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
             tb_writer=tb_writer,
@@ -726,7 +650,7 @@ def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -742,8 +666,8 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
+                sp=sp,
                 batch=batch,
-                graph_compiler=graph_compiler,
                 is_training=True,
             )
             loss.backward()
@@ -766,7 +690,6 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     world_size = args.world_size
     assert world_size >= 1