Add grad_clip and weight-decay, small fix of dataloader and masking

This commit is contained in:
pkufool 2021-08-19 17:34:28 +08:00
parent 5a0b9bcb23
commit 3060c5a556
5 changed files with 39 additions and 19 deletions

View File

@ -284,7 +284,7 @@ def decode_dataset(
results = [] results = []
num_cuts = 0 num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts) tot_num_batches = len(dl)
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -314,9 +314,8 @@ def decode_dataset(
if batch_idx % 100 == 0: if batch_idx % 100 == 0:
logging.info( logging.info(
f"batch {batch_idx}, cuts processed until now is " f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} " f"{num_cuts}"
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
) )
return results return results
@ -376,7 +375,7 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode") setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started") logging.info("Decoding started")
logging.info(params) logging.info(params)

View File

@ -16,6 +16,7 @@ import torch.nn as nn
from conformer import Conformer from conformer import Conformer
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from transformer import Noam
@ -114,7 +115,9 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval` is 0 - log_interval: Print training loss if batch_idx % log_interval` is 0
- valid_interval: Run validation if batch_idx % valid_interval` is 0 - valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- beam_size: It is used in k2.ctc_loss - beam_size: It is used in k2.ctc_loss
@ -127,7 +130,7 @@ def get_params() -> AttributeDict:
"exp_dir": Path("conformer_ctc/exp"), "exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"), "lang_dir": Path("data/lang_bpe"),
"feature_dim": 80, "feature_dim": 80,
"weight_decay": 0.0, "weight_decay": 1e-6,
"subsampling_factor": 4, "subsampling_factor": 4,
"start_epoch": 0, "start_epoch": 0,
"num_epochs": 50, "num_epochs": 50,
@ -137,11 +140,11 @@ def get_params() -> AttributeDict:
"best_valid_epoch": -1, "best_valid_epoch": -1,
"batch_idx_train": 0, "batch_idx_train": 0,
"log_interval": 10, "log_interval": 10,
"reset_interval": 200,
"valid_interval": 3000, "valid_interval": 3000,
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
"use_double_scores": True, "use_double_scores": True,
#
"accum_grad": 1, "accum_grad": 1,
"att_rate": 0.7, "att_rate": 0.7,
"attention_dim": 512, "attention_dim": 512,
@ -440,6 +443,8 @@ def train_one_epoch(
tot_att_loss = 0.0 tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches tot_frames = 0.0 # sum of frames over all batches
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
@ -457,6 +462,7 @@ def train_one_epoch(
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
clip_grad_value_(model.parameters(), 5.0)
optimizer.step() optimizer.step()
loss_cpu = loss.detach().cpu().item() loss_cpu = loss.detach().cpu().item()
@ -468,6 +474,9 @@ def train_one_epoch(
tot_ctc_loss += ctc_loss_cpu tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu tot_att_loss += att_loss_cpu
params.tot_frames += params.train_frames
params.tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames tot_avg_att_loss = tot_att_loss / tot_frames
@ -516,6 +525,12 @@ def train_one_epoch(
tot_avg_loss, tot_avg_loss,
params.batch_idx_train, params.batch_idx_train,
) )
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss( compute_validation_loss(
@ -551,7 +566,7 @@ def train_one_epoch(
params.batch_idx_train, params.batch_idx_train,
) )
params.train_loss = tot_loss / tot_frames params.train_loss = params.tot_loss / params.tot_frames
if params.train_loss < params.best_train_loss: if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch params.best_train_epoch = params.cur_epoch

View File

@ -274,9 +274,11 @@ class Transformer(nn.Module):
device device
) )
# TODO: Use eos_id as ignore_id. tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) # We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt) tgt = self.decoder_pos(tgt)
@ -339,9 +341,11 @@ class Transformer(nn.Module):
device device
) )
# TODO: Use eos_id as ignore_id. tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) # We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
tgt = self.decoder_pos(tgt) tgt = self.decoder_pos(tgt)

View File

@ -171,6 +171,8 @@ class AsrDataModule(DataModule):
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=True, shuffle=True,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
bucket_method='equal_duration',
drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") logging.info("Using SingleCutSampler.")
@ -184,8 +186,8 @@ class AsrDataModule(DataModule):
train, train,
sampler=train_sampler, sampler=train_sampler,
batch_size=None, batch_size=None,
num_workers=4, num_workers=2,
persistent_workers=True, persistent_workers=False,
) )
return train_dl return train_dl
@ -214,7 +216,7 @@ class AsrDataModule(DataModule):
sampler=valid_sampler, sampler=valid_sampler,
batch_size=None, batch_size=None,
num_workers=2, num_workers=2,
persistent_workers=True, persistent_workers=False,
) )
return valid_dl return valid_dl

View File

@ -610,7 +610,7 @@ def rescore_with_attention_decoder(
# Since k2.ragged.unique_sequences will reorder paths within a seq, # Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index # `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index. # to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1) # new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences( unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=True, need_new2old_indexes=True word_seq, need_num_repeats=True, need_new2old_indexes=True
) )