icefall (mirror of https://github.com/k2-fsa/icefall.git)
Commit 3bad661f6f, parent ef69661549: train.py draft..
@@ -17,18 +17,21 @@


import argparse
import collections
import logging
from pathlib import Path
import random  # temp..
from shutil import copyfile
from typing import Optional
from typing import List, Optional, Tuple

import k2
import torch
from torch import Tensor
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import DiscreteBottleneckConformer
from conformer import BidirectionalConformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@@ -153,7 +156,7 @@ def get_params() -> AttributeDict:
            "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
            "lang_dir": Path("data/lang_bpe"),
            "feature_dim": 80,
            "subsampling_factor": 4,
            "subsampling_factor": 4,  # can't be changed
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
@@ -166,12 +169,18 @@ def get_params() -> AttributeDict:
            "reduction": "sum",
            "use_double_scores": True,
            "accum_grad": 1,
            "att_rate": 0.7,
            "att_scale": 0.4,
            "reverse_att_scale": 0.4,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
            "attention_dim": 512,
            "nhead": 8,
            "num_trunk_encoder_layers": 12,
            "num_decoder_layers": 6,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "num_reverse_encoder_layers": 4,
            "num_reverse_decoder_layers": 4,
            "num_self_predictor_layers": 2,
            "discretization_tot_classes": 512,
            "discretization_num_groups": 8,
            "is_bpe": True,
            "use_feat_batchnorm": True,
            "max_lrate": 5.0e-04,
            "first_decay_epoch": 1,
@@ -270,15 +279,83 @@ def save_checkpoint(
        copyfile(src=filename, dst=best_valid_filename)


class LossRecord(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int(), which is zero.
        super(LossRecord, self).__init__(int)

    def __add__(self, other: "LossRecord") -> "LossRecord":
        ans = LossRecord()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "LossRecord":
        ans = LossRecord()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ''
        for k, v in self.norm_items():
            norm_value = '%.2g' % v
            ans += (str(k) + '=' + str(norm_value) + ', ')
        frames = str(self['frames'])
        ans += 'over ' + frames + ' frames.'
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('ctc_loss', 0.1), ('att_loss', 0.07)]
        """
        num_frames = self['frames'] if 'frames' in self else 1
        ans = []
        for k, v in self.items():
            if k != 'frames':
                norm_value = float(v) / num_frames
                ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Reduce using torch.distributed, which I believe ensures that
        all processes get the total.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v
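        # Assumption about usage: this relies on the default torch.distributed
        # process group having been initialized on every rank (e.g. via
        # dist.init_process_group()) before reduce() is called.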

    def write_summary(self, tb_writer: SummaryWriter, prefix: str, batch_idx: int) -> None:
        """
        Add logging information to a TensorBoard writer.
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_"
            or "train/current_"
          batch_idx: the current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)


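# A minimal usage sketch (hypothetical helper, not called by the training code):
# per-batch LossRecord objects are summed with '+', and norm_items() / __str__
# report each entry divided by the accumulated 'frames' count.
def _loss_record_usage_sketch() -> None:
    batch1 = LossRecord()
    batch1['frames'] = 100
    batch1['ctc_loss'] = 35.0

    batch2 = LossRecord()
    batch2['frames'] = 50
    batch2['ctc_loss'] = 20.0

    total = batch1 + batch2  # frames == 150, ctc_loss == 55.0
    assert total.norm_items() == [('ctc_loss', 55.0 / 150)]
    print(total)  # "ctc_loss=0.37, over 150 frames."
    # total.write_summary(tb_writer, "train/tot_", batch_idx) would log each
    # normalized entry to TensorBoard under the given prefix.

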
def compute_loss(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
    graph_compiler: BpeCtcTrainingGraphCompiler,
    is_training: bool,
):
) -> Tuple[Tensor, LossRecord]:
    """
    Compute CTC loss given the model and its inputs.
    Compute the loss function (including CTC, attention, and reverse-attention terms).

    Args:
      params:
@@ -306,9 +383,16 @@ def compute_loss(

        supervisions = batch["supervisions"]

        mmodel = model.module if hasattr(model, "module") else model

        with torch.set_grad_enabled(is_training):
            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
            # nnet_output is [N, T, C]
            memory, position_embedding, memory_mask = model(feature, supervisions)
            # memory's shape is (N, T, C)

            ctc_output = mmodel.ctc_encoder_forward(memory,
                                                    position_embedding,
                                                    memory_mask)


        # NOTE: We need `encode_supervisions` to sort sequences with
        # different duration in decreasing order, required by
@@ -322,7 +406,7 @@ def compute_loss(
        decoding_graph = graph_compiler.compile(token_ids)

        dense_fsa_vec = k2.DenseFsaVec(
            nnet_output,
            ctc_output,
            supervision_segments,
            allow_truncate=params.subsampling_factor - 1,
        )
@@ -335,38 +419,71 @@ def compute_loss(
            use_double_scores=params.use_double_scores,
        )

        if params.att_rate != 0.0:
        if params.att_scale != 0.0:
            with torch.set_grad_enabled(is_training):
                if hasattr(model, "module"):
                    att_loss = model.module.decoder_forward(
                        encoder_memory,
                        memory_mask,
                        token_ids=token_ids,
                        sos_id=graph_compiler.sos_id,
                        eos_id=graph_compiler.eos_id,
                    )
                else:
                    att_loss = model.decoder_forward(
                        encoder_memory,
                        memory_mask,
                        token_ids=token_ids,
                        sos_id=graph_compiler.sos_id,
                        eos_id=graph_compiler.eos_id,
                    )
            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
                att_loss = mmodel.decoder_forward(
                    memory,
                    memory_mask,
                    token_ids=token_ids,
                    sos_id=graph_compiler.sos_id,
                    eos_id=graph_compiler.eos_id,
                )
        else:
            loss = ctc_loss
            att_loss = torch.tensor([0])
            att_loss = torch.tensor([0.0]).to(device)

        # train_frames and valid_frames are used for printing.
        if is_training:
            params.train_frames = supervision_segments[:, 2].sum().item()
        if params.reverse_att_scale != 0.0:
            with torch.set_grad_enabled(is_training):
                (sampled, softmax,
                 positive_embed_shifted,
                 negative_embed_shifted) = mmodel.sample_forward(memory)

                reverse_decoder_logprob = mmodel.reverse_decoder_forward(
                    positive_embed_shifted,
                    memory_mask,
                    sampled, softmax,
                    token_ids,
                    sos_id=graph_compiler.sos_id,
                    eos_id=graph_compiler.eos_id,
                    padding_id=0)

                self_prediction_logprob = mmodel.self_prediction_forward(
                    negative_embed_shifted,
                    memory_mask,
                    sampled, softmax)

                # Note: reverse_att_loss is the mutual information between
                # the word sequence and the frames; it will generally be negative,
                # and is to be minimized (i.e. it moves away from zero as we
                # train; it does not approach zero).
                reverse_att_loss = self_prediction_logprob - reverse_decoder_logprob

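                # One rough reading (an interpretation, assuming the two
                # logprobs are roughly log p(sampled units | word sequence) and
                # log p(sampled units | nearby frames only) respectively):
                #     reverse_att_loss ~= -(log p(units | words) - log p(units)),
                # i.e. the negative of a pointwise mutual-information estimate,
                # so minimizing it pushes the MI estimate up and the value of
                # reverse_att_loss further below zero.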
                if random.random() < 0.01:
                    # Will eventually remove this block..
                    num_frames = supervision_segments[:, 2].sum().item()
                    print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
                          f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                          f"reverse_att_loss = {reverse_att_loss/num_frames}")
        else:
            params.valid_frames = supervision_segments[:, 2].sum().item()
            reverse_att_loss = torch.tensor([0.0]).to(device)

        ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale
        loss = (ctc_scale * ctc_loss +
                params.att_scale * att_loss +
                params.reverse_att_scale * reverse_att_loss)
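        # With the default scales above (att_scale=0.4, reverse_att_scale=0.4)
        # this gives ctc_scale = 1.0 - 0.4 - 0.4 = 0.2, i.e.
        #     loss = 0.2 * ctc_loss + 0.4 * att_loss + 0.4 * reverse_att_loss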
        assert loss.requires_grad == is_training

        return loss, ctc_loss.detach(), att_loss.detach()
        info = LossRecord()
        # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
        info['frames'] = supervision_segments[:, 2].sum().item()
        info['ctc_loss'] = ctc_loss.detach().cpu().item()
        if params.att_scale != 0.0:
            info['att_loss'] = att_loss.detach().cpu().item()
        if params.reverse_att_scale != 0.0:
            info['reverse_att_loss'] = reverse_att_loss.detach().cpu().item()
        info['loss'] = loss.detach().cpu().item()

        return loss, info
    except RuntimeError as e:
        print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
        raise e
@@ -381,18 +498,13 @@ def compute_validation_loss(
    graph_compiler: BpeCtcTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> None:
    """Run the validation process. The validation loss
    is saved in `params.valid_loss`.
    """
) -> LossRecord:
    """Run the validation process."""
    model.eval()

    tot_loss = 0.0
    tot_ctc_loss = 0.0
    tot_att_loss = 0.0
    tot_frames = 0.0
    tot_loss = LossRecord()
    for batch_idx, batch in enumerate(valid_dl):
        loss, ctc_loss, att_loss = compute_loss(
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            batch=batch,
@@ -400,36 +512,18 @@ def compute_validation_loss(
            is_training=False,
        )
        assert loss.requires_grad is False
        assert ctc_loss.requires_grad is False
        assert att_loss.requires_grad is False

        loss_cpu = loss.detach().cpu().item()
        tot_loss += loss_cpu

        tot_ctc_loss += ctc_loss.detach().cpu().item()
        tot_att_loss += att_loss.detach().cpu().item()

        tot_frames += params.valid_frames
        tot_loss = tot_loss + loss_info

    if world_size > 1:
        s = torch.tensor(
            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
            device=loss.device,
        )
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        s = s.cpu().tolist()
        tot_loss = s[0]
        tot_ctc_loss = s[1]
        tot_att_loss = s[2]
        tot_frames = s[3]
        tot_loss.reduce(loss.device)

    params.valid_loss = tot_loss / tot_frames
    params.valid_ctc_loss = tot_ctc_loss / tot_frames
    params.valid_att_loss = tot_att_loss / tot_frames

    if params.valid_loss < params.best_valid_loss:
    loss_value = tot_loss['loss'] / tot_loss['frames']
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = params.valid_loss
        params.best_valid_loss = loss_value

    return tot_loss


def train_one_epoch(
@@ -468,24 +562,20 @@ def train_one_epoch(
    """
    model.train()

    tot_loss = 0.0  # sum of losses over all batches
    tot_ctc_loss = 0.0
    tot_att_loss = 0.0
    tot_loss = LossRecord()

    tot_frames = 0.0  # sum of frames over all batches
    params.tot_loss = 0.0
    params.tot_frames = 0.0
    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

        loss, ctc_loss, att_loss = compute_loss(
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
        )
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info  # summary stats.
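        # In other words, tot_loss is an exponentially decaying running sum:
        # each batch the previous value is scaled by (1 - 1/reset_interval),
        # so it reflects roughly the last params.reset_interval batches rather
        # than the whole epoch.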

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.
@@ -495,75 +585,22 @@ def train_one_epoch(
        clip_grad_norm_(model.parameters(), 5.0, 2.0)
        optimizer.step()

        loss_cpu = loss.detach().cpu().item()
        ctc_loss_cpu = ctc_loss.detach().cpu().item()
        att_loss_cpu = att_loss.detach().cpu().item()
        if batch_idx % 10 == 0:

            tot_frames += params.train_frames
            tot_loss += loss_cpu
            tot_ctc_loss += ctc_loss_cpu
            tot_att_loss += att_loss_cpu

            params.tot_frames += params.train_frames
            params.tot_loss += loss_cpu

            tot_avg_loss = tot_loss / tot_frames
            tot_avg_ctc_loss = tot_ctc_loss / tot_frames
            tot_avg_att_loss = tot_att_loss / tot_frames
        if tb_writer is not None:
            loss_info.write_summary(tb_writer, "train/current_", params.batch_idx_train)
            tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)

        if batch_idx % params.log_interval == 0:
            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
                f"total avg att loss: {tot_avg_att_loss:.4f}, "
                f"total avg loss: {tot_avg_loss:.4f}, "
                f"batch size: {batch_size}"
                f"Epoch {params.cur_epoch}, batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}"
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/current_ctc_loss",
                    ctc_loss_cpu / params.train_frames,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/current_att_loss",
                    att_loss_cpu / params.train_frames,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/current_loss",
                    loss_cpu / params.train_frames,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/tot_avg_ctc_loss",
                    tot_avg_ctc_loss,
                    params.batch_idx_train,
                )

                tb_writer.add_scalar(
                    "train/tot_avg_att_loss",
                    tot_avg_att_loss,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/tot_avg_loss",
                    tot_avg_loss,
                    params.batch_idx_train,
                )
        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
            tot_loss = 0.0  # sum of losses over all batches
            tot_ctc_loss = 0.0
            tot_att_loss = 0.0

            tot_frames = 0.0  # sum of frames over all batches

        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
            compute_validation_loss(
            logging.info("Computing validation loss")
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
@@ -572,32 +609,14 @@ def train_one_epoch(
            )
            model.train()
            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"valid ctc loss {params.valid_ctc_loss:.4f},"
                f"valid att loss {params.valid_att_loss:.4f},"
                f"valid loss {params.valid_loss:.4f},"
                f" best valid loss: {params.best_valid_loss:.4f} "
                f"best valid epoch: {params.best_valid_epoch}"
                f"Epoch {params.cur_epoch}, validation: {valid_info}"
            )
            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/valid_ctc_loss",
                    params.valid_ctc_loss,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/valid_att_loss",
                    params.valid_att_loss,
                    params.batch_idx_train,
                )
                tb_writer.add_scalar(
                    "train/valid_loss",
                    params.valid_loss,
                    params.batch_idx_train,
                )
                valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)

    params.train_loss = params.tot_loss / params.tot_frames

    loss_value = tot_loss['loss'] / tot_loss['frames']
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
@@ -647,17 +666,21 @@ def run(rank, world_size, args):
    )

    logging.info("About to create model")
    model = DiscreteBottleneckConformer(
    model = BidirectionalConformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        d_model=params.attention_dim,
        nhead=params.nhead,
        num_trunk_encoder_layers=params.num_trunk_encoder_layers,
        num_ctc_encoder_layers=params.num_ctc_encoder_layers,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=False,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
        num_reverse_encoder_layers=params.num_reverse_encoder_layers,
        num_reverse_decoder_layers=params.num_reverse_decoder_layers,
        num_self_predictor_layers=params.num_self_predictor_layers,
        subsampling_factor=params.subsampling_factor,
        is_bpe=params.is_bpe,
        discretization_tot_classes=params.discretization_tot_classes,
        discretization_num_groups=params.discretization_num_groups,
    )

    checkpoints = load_checkpoint_if_available(params=params, model=model)