From 3060c5a55686d910bb0282efd0f53e03e9325b78 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Thu, 19 Aug 2021 17:34:28 +0800
Subject: [PATCH] Add grad_clip and weight-decay, small fix of dataloader and masking

---
 egs/librispeech/ASR/conformer_ctc/decode.py   |  9 ++++----
 egs/librispeech/ASR/conformer_ctc/train.py    | 23 +++++++++++++++----
 .../ASR/conformer_ctc/transformer.py          | 16 ++++++++-----
 icefall/dataset/asr_datamodule.py             |  8 ++++---
 icefall/decode.py                             |  2 +-
 5 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 889a0a474..fbb5e096b 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -284,7 +284,7 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
+    tot_num_batches = len(dl)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -314,9 +314,8 @@
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
+                f"{num_cuts}"
             )
 
     return results
@@ -376,7 +375,7 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)
 
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 552db81ec..c2f98fc62 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -16,6 +16,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_value_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
@@ -114,7 +115,9 @@ def get_params() -> AttributeDict:
 
         - log_interval: Print training loss if batch_idx % log_interval` is 0
 
-        - valid_interval: Run validation if batch_idx % valid_interval` is 0
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
 
         - beam_size: It is used in k2.ctc_loss
 
@@ -127,7 +130,7 @@
             "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 0.0,
+            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "start_epoch": 0,
             "num_epochs": 50,
@@ -137,11 +140,11 @@
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 200,
             "valid_interval": 3000,
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
-            # "accum_grad": 1,
 
             "att_rate": 0.7,
             "attention_dim": 512,
@@ -440,6 +443,8 @@ def train_one_epoch(
     tot_att_loss = 0.0
 
     tot_frames = 0.0  # sum of frames over all batches
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -457,6 +462,7 @@
 
         optimizer.zero_grad()
         loss.backward()
+        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
@@ -468,6 +474,9 @@
         tot_ctc_loss += ctc_loss_cpu
         tot_att_loss += att_loss_cpu
 
+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         tot_avg_loss = tot_loss / tot_frames
         tot_avg_ctc_loss = tot_ctc_loss / tot_frames
         tot_avg_att_loss = tot_att_loss / tot_frames
@@ -516,6 +525,12 @@
                     tot_avg_loss,
                     params.batch_idx_train,
                 )
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0.0  # sum of losses over all batches
+            tot_ctc_loss = 0.0
+            tot_att_loss = 0.0
+
+            tot_frames = 0.0  # sum of frames over all batches
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
@@ -551,7 +566,7 @@
                     params.batch_idx_train,
                 )
 
-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py
index a974be4e0..1767da7e8 100644
--- a/egs/librispeech/ASR/conformer_ctc/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -274,9 +274,11 @@ class Transformer(nn.Module):
             device
         )
 
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
         tgt = self.decoder_pos(tgt)
@@ -339,9 +341,11 @@
             device
         )
 
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
         tgt = self.decoder_pos(tgt)
diff --git a/icefall/dataset/asr_datamodule.py b/icefall/dataset/asr_datamodule.py
index aae7af9ce..73eef9c31 100644
--- a/icefall/dataset/asr_datamodule.py
+++ b/icefall/dataset/asr_datamodule.py
@@ -171,6 +171,8 @@ class AsrDataModule(DataModule):
                 max_duration=self.args.max_duration,
                 shuffle=True,
                 num_buckets=self.args.num_buckets,
+                bucket_method='equal_duration',
+                drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
@@ -184,8 +186,8 @@
             train,
             sampler=train_sampler,
             batch_size=None,
-            num_workers=4,
-            persistent_workers=True,
+            num_workers=2,
+            persistent_workers=False,
         )
 
         return train_dl
@@ -214,7 +216,7 @@
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=True,
+            persistent_workers=False,
         )
 
         return valid_dl
diff --git a/icefall/decode.py b/icefall/decode.py
index 0e9baf2e4..2d3e1ed56 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -610,7 +610,7 @@ def rescore_with_attention_decoder(
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
-    # new2old.numel() == unique_word_seqs.tot_size(1)
+    # new2old.numel() == unique_word_seq.tot_size(1)
     unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
         word_seq, need_num_repeats=True, need_new2old_indexes=True
     )
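
Note (illustration, not part of the patch): the snippet below is a minimal,
self-contained sketch of the two ideas this commit adds to train.py and
transformer.py -- clipping gradients between backward() and step() alongside a
small weight decay, and building the decoder padding mask from eos_id with the
first column forced to False. The model, optimizer, tensors and eos_id value are
illustrative stand-ins only; the recipe itself uses the Conformer model, the Noam
optimizer wrapper imported in train.py (the patch raises the weight_decay entry
in params from 0.0 to 1e-6), and the decoder_padding_mask helper from
transformer.py.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

eos_id = 1  # stand-in id; in this recipe sos_id == eos_id, hence the column-0 fix


def make_tgt_key_padding_mask(ys_in_pad: torch.Tensor) -> torch.Tensor:
    # True marks padded positions the decoder should ignore; this mirrors
    # decoder_padding_mask(ys_in_pad, ignore_id=eos_id) in transformer.py.
    mask = ys_in_pad == eos_id
    # Column 0 holds sos_id (== eos_id here), so it must stay unmasked.
    mask[:, 0] = False
    return mask


ys_in_pad = torch.tensor([[1, 5, 7, 1], [1, 9, 1, 1]])
print(make_tgt_key_padding_mask(ys_in_pad))
# tensor([[False, False, False,  True],
#         [False, False,  True,  True]])

model = nn.Linear(80, 500)  # stand-in for the Conformer model
optimizer = torch.optim.AdamW(  # stand-in for the Noam wrapper
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-6,  # same value the patch sets in params
)

features = torch.randn(4, 80)
targets = torch.randn(4, 500)
loss = nn.functional.mse_loss(model(features), targets)

optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)  # same call train.py now makes
optimizer.step()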