train2.py not working due to issues in distributed training, hard to fix

This commit is contained in:
Daniel Povey 2021-09-22 12:20:17 +08:00
parent 6f8b7b9c3b
commit 65b737576e

View File

@ -15,20 +15,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#--master-port 12344 --world-size 3 --max-duration=200 --bucketing-sampler=True --start-epoch=5
import argparse
import collections
import logging
from pathlib import Path
import random # temp..
from shutil import copyfile
from typing import Optional
from typing import Optional, Tuple, List
import k2
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import DiscreteBottleneckConformer
from conformer import BidirectionalConformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@ -150,10 +155,10 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
"exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_rand"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"subsampling_factor": 4,
"subsampling_factor": 4, # can't be changed
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -166,12 +171,19 @@ def get_params() -> AttributeDict:
"reduction": "sum",
"use_double_scores": True,
"accum_grad": 1,
"att_rate": 0.7,
"att_scale": 0.3,
"reverse_att_scale": 0.2,
"bottleneck_ctc_scale": 0.2, # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
"attention_dim": 512,
"nhead": 8,
"num_trunk_encoder_layers": 12,
"num_ctc_encoder_layers": 2,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"mmi_loss": False,
"num_reverse_encoder_layers": 4,
"num_reverse_decoder_layers": 4,
"discretization_tot_classes": 512,
"discretization_num_groups": 8,
"is_bpe": True,
"use_feat_batchnorm": True,
"max_lrate": 5.0e-04,
"first_decay_epoch": 1,
@ -270,15 +282,133 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
class SharedRandom:
def __init__(self):
self.init()
def __call__(self):
count = self.count
if count > self.rand.numel():
self.init()
count = self.count
self.count += 1
return self.rand[count].item()
def init(self):
num_random = 10000
world_size = dist.get_world_size()
rand = torch.rand(num_random, device='cuda')
if world_size > 1:
# Copy from process with rank 0.
torch.distributed.broadcast(rand, src=0)
self.rand = rand.to(device='cpu')
self.count = 0
class LossRecord(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
super(LossRecord, self).__init__(int)
def __add__(self, other: 'LossRecord') -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans
def __mul__(self, alpha: float) -> 'LossRecord':
ans = LossRecord()
for k, v in self.items():
ans[k] = v * alpha
return ans
def __str__(self) -> str:
ans = ''
for k, v in self.norm_items():
norm_value = '%.4g' % v
ans += (str(k) + '=' + str(norm_value) + ', ')
frames = str(self['frames'])
ans += 'over ' + frames + ' frames.'
return ans
def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self['frames'] if 'frames' in self else 1
ans = []
for k, v in self.items():
if k != 'frames':
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):
"""
Reduce using torch.distributed, which I believe ensures that
all processes get the total.
"""
keys = sorted(self.keys())
s = torch.tensor([ float(self[k]) for k in keys ],
device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v
def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
"""
Add logging information to a TensorBoard writer.
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
def get_loss_type(params: AttributeDict) -> str:
"""
Returns one of: 'ctc', 'bottleneck_ctc', 'attn', 'reverse_attn', with
probabilities as determined by the params.
"""
try:
r = params.shared_random()
except:
params.shared_random = SharedRandom()
r = params.shared_random()
assert 0 <= r and r <= 1
if r < params.att_scale:
return 'att'
else:
r -= params.att_scale
if r < params.reverse_att_scale:
return 'reverse_att'
else:
r -= params.reverse_att_scale
if r < params.bottleneck_ctc_scale:
return 'bottleneck_ctc'
else:
return 'ctc'
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
):
) -> Tuple[Tensor, LossRecord]:
"""
Compute CTC loss given the model and its inputs.
Compute loss function (including CTC, attention, and reverse-attention terms).
Args:
params:
@ -306,67 +436,132 @@ def compute_loss(
supervisions = batch["supervisions"]
mmodel = model.module if hasattr(model, "module") else model
loss_type = get_loss_type(params)
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
memory, position_embedding, memory_mask = model(feature, supervisions)
# memory's shape is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
if loss_type == 'ctc':
ctc_output = mmodel.ctc_encoder_forward(memory,
position_embedding,
memory_mask)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
token_ids = graph_compiler.texts_to_ids(texts)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
decoding_graph = graph_compiler.compile(token_ids)
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
elif loss_type == 'att':
loss = mmodel.decoder_forward(
memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
elif loss_type == 'reverse_att':
(sampled, softmax, positive_embed,
positive_embed_shifted,
negative_embed_shifted) = mmodel.sample_forward(memory)
reverse_decoder_logprob = mmodel.reverse_decoder_forward(
positive_embed_shifted,
memory_mask,
sampled, softmax,
token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
padding_id=0)
self_prediction_logprob = mmodel.self_prediction_forward(
negative_embed_shifted,
memory_mask,
sampled, softmax)
# Note: reverse_att_loss is the mutual information between
# the word sequence and the frames; it will generally be negative,
# and is to be minimized (i.e. it goes away from zero as we train,
# it does not approach zero).
loss = self_prediction_logprob - reverse_decoder_logprob
if random.random() < 0.01:
# Will eventually remove this block..
num_frames = supervision_segments[:, 2].sum().item()
print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
f"reverse_att_loss = {reverse_att_loss/num_frames}")
else:
assert loss_type == 'bottleneck_ctc'
(_, _, positive_embed, _, _) = mmodel.sample_forward(memory)
bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed,
position_embedding,
memory_mask)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
bottleneck_ctc_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert loss.requires_grad == is_training
return loss, ctc_loss.detach(), att_loss.detach()
info = LossRecord()
# TODO: there are many GPU->CPU transfers here, maybe combine them into one.
info['frames'] = supervision_segments[:, 2].sum().item()
loss_val = loss.detach().cpu().item()
info[loss_type] = loss_val
info['loss'] = loss_val
loss_scale = 0.001 if (loss_type == 'reverse_att_loss' and params.cur_epoch == 0) else 1.0
# Make sure this output of forward() participates in the
# loss... required for torch.distributed
loss += 0.0 * position_embedding.sum()
return loss * loss_scale, info
except RuntimeError as e:
print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
raise e
@ -381,18 +576,13 @@ def compute_validation_loss(
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
) -> LossRecord:
"""Run the validation process. """
model.eval()
tot_loss = 0.0
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
tot_loss = LossRecord()
for batch_idx, batch in enumerate(valid_dl):
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
@ -400,36 +590,18 @@ def compute_validation_loss(
is_training=False,
)
assert loss.requires_grad is False
assert ctc_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
tot_loss = tot_loss + loss_info
if world_size > 1:
s = torch.tensor(
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]
tot_loss.reduce(loss.device)
params.valid_loss = tot_loss / tot_frames
params.valid_ctc_loss = tot_ctc_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
loss_value = tot_loss['loss'] / tot_loss['frames']
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
@ -468,24 +640,20 @@ def train_one_epoch(
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_loss = LossRecord()
tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, ctc_loss, att_loss = compute_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # summary stats.
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
@ -495,75 +663,22 @@ def train_one_epoch(
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()
if batch_idx % 10 == 0:
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames
if tb_writer is not None:
loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/current_ctc_loss",
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx % params.valid_interval == 0:
compute_validation_loss(
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
@ -572,32 +687,14 @@ def train_one_epoch(
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
f"Epoch {params.cur_epoch}, validation: {valid_info}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
params.train_loss = params.tot_loss / params.tot_frames
loss_value = tot_loss['loss'] / tot_loss['frames']
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
@ -647,17 +744,20 @@ def run(rank, world_size, args):
)
logging.info("About to create model")
model = DiscreteBottleneckConformer(
model = BidirectionalConformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
num_trunk_encoder_layers=params.num_trunk_encoder_layers,
num_ctc_encoder_layers=params.num_ctc_encoder_layers,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
num_reverse_encoder_layers=params.num_reverse_encoder_layers,
num_reverse_decoder_layers=params.num_reverse_decoder_layers,
subsampling_factor=params.subsampling_factor,
is_bpe=params.is_bpe,
discretization_tot_classes=params.discretization_tot_classes,
discretization_num_groups=params.discretization_num_groups,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
@ -685,6 +785,7 @@ def run(rank, world_size, args):
for epoch in range(params.start_epoch, params.num_epochs):
optimizer.set_epoch(epoch) # specific to Gloam
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
cur_lr = optimizer._rate
if tb_writer is not None: