train.py draft..

Daniel Povey 2021-09-19 21:18:12 +08:00
parent ef69661549
commit 3bad661f6f


@@ -17,18 +17,21 @@
 import argparse
+import collections
 import logging
 from pathlib import Path
+import random  # temp..
 from shutil import copyfile
-from typing import Optional
+from typing import List, Optional, Tuple

 import k2
 import torch
+from torch import Tensor
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import DiscreteBottleneckConformer
+from conformer import BidirectionalConformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -153,7 +156,7 @@ def get_params() -> AttributeDict:
             "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "subsampling_factor": 4,
+            "subsampling_factor": 4,  # can't be changed
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -166,12 +169,18 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_rate": 0.7,
+            "att_scale": 0.4,
+            "reverse_att_scale": 0.4,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
             "attention_dim": 512,
             "nhead": 8,
+            "num_trunk_encoder_layers": 12,
             "num_decoder_layers": 6,
-            "is_espnet_structure": True,
-            "mmi_loss": False,
+            "num_reverse_encoder_layers": 4,
+            "num_reverse_decoder_layers": 4,
+            "num_self_predictor_layers": 2,
+            "discretization_tot_classes": 512,
+            "discretization_num_groups": 8,
+            "is_bpe": True,
             "use_feat_batchnorm": True,
             "max_lrate": 5.0e-04,
             "first_decay_epoch": 1,
@@ -270,15 +279,83 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+class LossRecord(collections.defaultdict):
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        super(LossRecord, self).__init__(int)
+
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> 'LossRecord':
+        ans = LossRecord()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        ans = ''
+        for k, v in self.norm_items():
+            norm_value = '%.2g' % v
+            ans += (str(k) + '=' + str(norm_value) + ', ')
+        frames = str(self['frames'])
+        ans += 'over ' + frames + ' frames.'
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('ctc_loss', 0.1), ('att_loss', 0.07)]
+        """
+        num_frames = self['frames'] if 'frames' in self else 1
+        ans = []
+        for k, v in self.items():
+            if k != 'frames':
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+        """
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
+        """
+        Add logging information to a TensorBoard writer.
+          tb_writer: a TensorBoard writer
+          prefix: a prefix for the name of the loss, e.g. "train/valid_",
+            or "train/current_"
+          batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-):
+) -> Tuple[Tensor, LossRecord]:
     """
-    Compute CTC loss given the model and its inputs.
+    Compute loss function (including CTC, attention, and reverse-attention terms).

     Args:
       params:
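A short usage sketch of the LossRecord class added above (illustrative, with made-up numbers; not part of the commit): undefined keys default to zero, __add__ accumulates per-key sums, and __str__ / norm_items() report per-frame averages.

    # Accumulate two hypothetical batches and print per-frame averages.
    batch1 = LossRecord()
    batch1['frames'] = 1000
    batch1['ctc_loss'] = 250.0
    batch1['att_loss'] = 120.0

    batch2 = LossRecord()
    batch2['frames'] = 800
    batch2['ctc_loss'] = 190.0
    batch2['att_loss'] = 100.0

    tot = batch1 + batch2   # per-key sums: 1800 frames, 440.0, 220.0
    print(tot)              # -> "ctc_loss=0.24, att_loss=0.12, over 1800 frames."
    half = tot * 0.5        # __mul__ scales every entry, including 'frames'
    # Under DDP, tot.reduce(device) would all-reduce these sums across ranks, and
    # tot.write_summary(tb_writer, "train/tot_", step) would log the averages.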
@@ -306,9 +383,16 @@ def compute_loss(
         supervisions = batch["supervisions"]
+
+        mmodel = model.module if hasattr(model, "module") else model
+
         with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-            # nnet_output is [N, T, C]
+            memory, position_embedding, memory_mask = model(feature, supervisions)
+            # memory's shape is (N, T, C)
+
+            ctc_output = mmodel.ctc_encoder_forward(memory,
+                                                    position_embedding,
+                                                    memory_mask)

         # NOTE: We need `encode_supervisions` to sort sequences with
         # different duration in decreasing order, required by
@@ -322,7 +406,7 @@ def compute_loss(
         decoding_graph = graph_compiler.compile(token_ids)

         dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
+            ctc_output,
             supervision_segments,
             allow_truncate=params.subsampling_factor - 1,
         )
@@ -335,38 +419,71 @@ def compute_loss(
             use_double_scores=params.use_double_scores,
         )

-        if params.att_rate != 0.0:
-            with torch.set_grad_enabled(is_training):
-                if hasattr(model, "module"):
-                    att_loss = model.module.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-                else:
-                    att_loss = model.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-        else:
-            loss = ctc_loss
-            att_loss = torch.tensor([0])
-
-        # train_frames and valid_frames are used for printing.
-        if is_training:
-            params.train_frames = supervision_segments[:, 2].sum().item()
-        else:
-            params.valid_frames = supervision_segments[:, 2].sum().item()
+        if params.att_scale != 0.0:
+            with torch.set_grad_enabled(is_training):
+                att_loss = mmodel.decoder_forward(
+                    memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+        else:
+            att_loss = torch.tensor([0.0]).to(device)
+
+        if params.reverse_att_scale != 0.0:
+            with torch.set_grad_enabled(is_training):
+                (sampled, softmax,
+                 positive_embed_shifted,
+                 negative_embed_shifted) = mmodel.sample_forward(memory)
+                reverse_decoder_logprob = mmodel.reverse_decoder_forward(
+                    positive_embed_shifted,
+                    memory_mask,
+                    sampled, softmax,
+                    token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                    padding_id=0)
+                self_prediction_logprob = mmodel.self_prediction_forward(
+                    negative_embed_shifted,
+                    memory_mask,
+                    sampled, softmax)
+                # Note: reverse_att_loss is the mutual information between
+                # the word sequence and the frames; it will generally be negative,
+                # and is to be minimized (i.e. it goes away from zero as we train,
+                # it does not approach zero).
+                reverse_att_loss = self_prediction_logprob - reverse_decoder_logprob
+                if random.random() < 0.01:
+                    # Will eventually remove this block..
+                    num_frames = supervision_segments[:, 2].sum().item()
+                    print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
+                          f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
+                          f"reverse_att_loss = {reverse_att_loss/num_frames}")
+        else:
+            reverse_att_loss = torch.tensor([0.0]).to(device)
+
+        ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale
+        loss = (ctc_scale * ctc_loss +
+                params.att_scale * att_loss +
+                params.reverse_att_scale * reverse_att_loss)

         assert loss.requires_grad == is_training

-        return loss, ctc_loss.detach(), att_loss.detach()
+        info = LossRecord()
+        # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
+        info['frames'] = supervision_segments[:, 2].sum().item()
+        info['ctc_loss'] = ctc_loss.detach().cpu().item()
+        if params.att_scale != 0.0:
+            info['att_loss'] = att_loss.detach().cpu().item()
+        if params.reverse_att_scale != 0.0:
+            info['reverse_att_loss'] = reverse_att_loss.detach().cpu().item()
+        info['loss'] = loss.detach().cpu().item()
+
+        return loss, info
     except RuntimeError as e:
         print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
         raise e
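To make the sign convention of the reverse-attention term concrete, here is a toy calculation with made-up numbers. The interpretation of the two log-probs is my reading of the draft (the reverse decoder can condition on the word sequence, the self-predictor cannot); the commit does not state it explicitly.

    # Both quantities are summed log-probs of the sampled discrete codes over a batch.
    reverse_decoder_logprob = -3500.0   # made up: prediction that sees the word sequence
    self_prediction_logprob = -5200.0   # made up: prediction from the frames alone
    reverse_att_loss = self_prediction_logprob - reverse_decoder_logprob   # -1700.0
    # The difference is (minus) an estimate of the mutual information between the
    # codes and the words, so training pushes it further below zero, not toward zero.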
@@ -381,18 +498,13 @@ def compute_validation_loss(
     graph_compiler: BpeCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
-) -> None:
-    """Run the validation process. The validation loss
-    is saved in `params.valid_loss`.
-    """
+) -> LossRecord:
+    """Run the validation process."""
     model.eval()
-    tot_loss = 0.0
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0
+    tot_loss = LossRecord()

     for batch_idx, batch in enumerate(valid_dl):
-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -400,36 +512,18 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
-
-        loss_cpu = loss.detach().cpu().item()
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss.detach().cpu().item()
-        tot_att_loss += att_loss.detach().cpu().item()
-        tot_frames += params.valid_frames
+        tot_loss = tot_loss + loss_info

     if world_size > 1:
-        s = torch.tensor(
-            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
-            device=loss.device,
-        )
-        dist.all_reduce(s, op=dist.ReduceOp.SUM)
-        s = s.cpu().tolist()
-        tot_loss = s[0]
-        tot_ctc_loss = s[1]
-        tot_att_loss = s[2]
-        tot_frames = s[3]
+        tot_loss.reduce(loss.device)

-    params.valid_loss = tot_loss / tot_frames
-    params.valid_ctc_loss = tot_ctc_loss / tot_frames
-    params.valid_att_loss = tot_att_loss / tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']

-    if params.valid_loss < params.best_valid_loss:
+    if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = params.valid_loss
+        params.best_valid_loss = loss_value
+
+    return tot_loss


 def train_one_epoch(
@@ -468,24 +562,20 @@ def train_one_epoch(
     """
     model.train()

-    tot_loss = 0.0  # sum of losses over all batches
-    tot_ctc_loss = 0.0
-    tot_att_loss = 0.0
-    tot_frames = 0.0  # sum of frames over all batches
-    params.tot_loss = 0.0
-    params.tot_frames = 0.0
+    tot_loss = LossRecord()
+
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss, ctc_loss, att_loss = compute_loss(
+        loss, loss_info = compute_loss(
             params=params,
             model=model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
         )
+        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info  # summary stats.

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
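The tot_loss update above keeps a decaying sum rather than a strict total, so the logged "tot_" values reflect roughly the last reset_interval batches. A small standalone sketch of the idea (reset_interval = 200 is an assumed value, not taken from this diff):

    # Each step the running stats shrink by 1/reset_interval before the new batch
    # is added, so old batches fade out instead of accumulating forever.
    reset_interval = 200
    decay = 1.0 - 1.0 / reset_interval

    running_loss = 0.0
    running_frames = 0.0
    for batch_loss, batch_frames in [(250.0, 1000), (190.0, 800), (220.0, 900)]:
        running_loss = running_loss * decay + batch_loss
        running_frames = running_frames * decay + batch_frames
    print(running_loss / running_frames)   # per-frame average over recent batches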
@@ -495,75 +585,22 @@ def train_one_epoch(
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

-        loss_cpu = loss.detach().cpu().item()
-        ctc_loss_cpu = ctc_loss.detach().cpu().item()
-        att_loss_cpu = att_loss.detach().cpu().item()
-
-        tot_frames += params.train_frames
-        tot_loss += loss_cpu
-        tot_ctc_loss += ctc_loss_cpu
-        tot_att_loss += att_loss_cpu
-        params.tot_frames += params.train_frames
-        params.tot_loss += loss_cpu
-
-        tot_avg_loss = tot_loss / tot_frames
-        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
-        tot_avg_att_loss = tot_att_loss / tot_frames
+        if batch_idx % 10 == 0:
+            if tb_writer is not None:
+                loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

         if batch_idx % params.log_interval == 0:
             logging.info(
-                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
-                f"total avg att loss: {tot_avg_att_loss:.4f}, "
-                f"total avg loss: {tot_avg_loss:.4f}, "
-                f"batch size: {batch_size}"
+                f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], batch size: {batch_size}"
             )

-            if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/current_ctc_loss",
-                    ctc_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/current_att_loss",
-                    att_loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/current_loss",
-                    loss_cpu / params.train_frames,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_ctc_loss",
-                    tot_avg_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_att_loss",
-                    tot_avg_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
-                )
-
-        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
-            tot_loss = 0.0  # sum of losses over all batches
-            tot_ctc_loss = 0.0
-            tot_att_loss = 0.0
-            tot_frames = 0.0  # sum of frames over all batches
-
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
-            compute_validation_loss(
+            logging.info("Computing validation loss")
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 graph_compiler=graph_compiler,
@@ -572,32 +609,14 @@ def train_one_epoch(
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, "
-                f"valid ctc loss {params.valid_ctc_loss:.4f},"
-                f"valid att loss {params.valid_att_loss:.4f},"
-                f"valid loss {params.valid_loss:.4f},"
-                f" best valid loss: {params.best_valid_loss:.4f} "
-                f"best valid epoch: {params.best_valid_epoch}"
+                f"Epoch {params.cur_epoch}, validation: {valid_info}"
             )
             if tb_writer is not None:
-                tb_writer.add_scalar(
-                    "train/valid_ctc_loss",
-                    params.valid_ctc_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_att_loss",
-                    params.valid_att_loss,
-                    params.batch_idx_train,
-                )
-                tb_writer.add_scalar(
-                    "train/valid_loss",
-                    params.valid_loss,
-                    params.batch_idx_train,
-                )
+                valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)

-    params.train_loss = params.tot_loss / params.tot_frames
+    loss_value = tot_loss['loss'] / tot_loss['frames']
+    params.train_loss = loss_value

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
@@ -647,17 +666,21 @@ def run(rank, world_size, args):
     )
     logging.info("About to create model")
-    model = DiscreteBottleneckConformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        is_espnet_structure=params.is_espnet_structure,
-        mmi_loss=params.mmi_loss,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
+    model = BidirectionalConformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        d_model=params.attention_dim,
+        nhead=params.nhead,
+        num_trunk_encoder_layers=params.num_trunk_encoder_layers,
+        num_ctc_encoder_layers=params.num_ctc_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+        num_reverse_encoder_layers=params.num_reverse_encoder_layers,
+        num_reverse_decoder_layers=params.num_reverse_decoder_layers,
+        num_self_predictor_layers=params.num_self_predictor_layers,
+        subsampling_factor=params.subsampling_factor,
+        is_bpe=params.is_bpe,
+        discretization_tot_classes=params.discretization_tot_classes,
+        discretization_num_groups=params.discretization_num_groups,
+    )

     checkpoints = load_checkpoint_if_available(params=params, model=model)