Add grad_clip and weight-decay, small fix of dataloader and masking

commit 3060c5a556
parent 5a0b9bcb23
@@ -284,7 +284,7 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
+    tot_num_batches = len(dl)
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -314,9 +314,8 @@ def decode_dataset(
 
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
+                f"{num_cuts}"
             )
     return results
 
@@ -376,7 +375,7 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    setup_logger(f"{params.exp_dir}/log/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
     logging.info(params)
 
@@ -16,6 +16,7 @@ import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_value_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
@@ -114,7 +115,9 @@ def get_params() -> AttributeDict:
 
         - log_interval: Print training loss if batch_idx % log_interval` is 0
 
-        - valid_interval: Run validation if batch_idx % valid_interval` is 0
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
 
         - beam_size: It is used in k2.ctc_loss
 
@@ -127,7 +130,7 @@ def get_params() -> AttributeDict:
             "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 0.0,
+            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "start_epoch": 0,
             "num_epochs": 50,
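Note: a minimal sketch of where a weight-decay value like the new 1e-6 default typically ends up, assuming the Noam wrapper forwards it to a torch optimizer; the model and learning rate below are placeholders, not the recipe's actual setup.

import torch.nn as nn
import torch.optim as optim

# Placeholder model standing in for the Conformer; in the recipe the optimizer
# is built via the Noam wrapper imported from transformer.py.
model = nn.Linear(80, 500)

# weight_decay adds an L2 penalty to every parameter update; 1e-6 (the new
# default above) is a very light regularizer.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)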
@@ -137,11 +140,11 @@ def get_params() -> AttributeDict:
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
+            "reset_interval": 200,
             "valid_interval": 3000,
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
-            #
             "accum_grad": 1,
             "att_rate": 0.7,
             "attention_dim": 512,
@@ -440,6 +443,8 @@ def train_one_epoch(
     tot_att_loss = 0.0
 
     tot_frames = 0.0  # sum of frames over all batches
+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -457,6 +462,7 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
+        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
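Note: a minimal sketch of the clipping step added above, assuming a plain torch training loop; the model, batch, and loss are placeholders. The clip_grad_value_ call has to sit between backward() (so gradients exist) and step() (so the clipped gradients are what the optimizer applies); it clamps each gradient element to [-5.0, 5.0].

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

model = nn.Linear(80, 500)                 # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

features = torch.randn(8, 80)              # placeholder batch
loss = model(features).pow(2).sum()        # placeholder loss

optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)  # clamp gradients element-wise to [-5, 5]
optimizer.step()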
@@ -468,6 +474,9 @@ def train_one_epoch(
         tot_ctc_loss += ctc_loss_cpu
         tot_att_loss += att_loss_cpu
 
+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         tot_avg_loss = tot_loss / tot_frames
         tot_avg_ctc_loss = tot_ctc_loss / tot_frames
         tot_avg_att_loss = tot_att_loss / tot_frames
@@ -516,6 +525,12 @@ def train_one_epoch(
                 tot_avg_loss,
                 params.batch_idx_train,
             )
+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0.0  # sum of losses over all batches
+            tot_ctc_loss = 0.0
+            tot_att_loss = 0.0
+
+            tot_frames = 0.0  # sum of frames over all batches
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
@@ -551,7 +566,7 @@ def train_one_epoch(
             params.batch_idx_train,
         )
 
-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames
 
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
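Note: after these changes there are two sets of running sums. tot_loss / tot_frames are cleared every reset_interval batches, so the logged tot_avg_* values reflect only recent batches, while params.tot_loss / params.tot_frames keep accumulating so that params.train_loss is a per-frame average over the whole epoch. A small self-contained sketch of that pattern with made-up numbers:

# Windowed vs. whole-epoch per-frame loss, mirroring the bookkeeping above.
reset_interval = 3                       # the recipe uses 200

tot_loss, tot_frames = 0.0, 0.0          # recent window (reset periodically)
epoch_loss, epoch_frames = 0.0, 0.0      # whole epoch (never reset)

fake_batches = [(120.0, 40), (90.0, 30), (150.0, 50), (60.0, 20)]  # (loss, frames)
for batch_idx, (loss, frames) in enumerate(fake_batches):
    tot_loss += loss
    tot_frames += frames
    epoch_loss += loss
    epoch_frames += frames
    print(f"batch {batch_idx}: recent per-frame loss {tot_loss / tot_frames:.3f}")
    if batch_idx > 0 and batch_idx % reset_interval == 0:
        tot_loss, tot_frames = 0.0, 0.0  # start a fresh logging window

print(f"epoch per-frame loss {epoch_loss / epoch_frames:.3f}")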
@@ -274,9 +274,11 @@ class Transformer(nn.Module):
             device
         )
 
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (N, T) -> (N, T, C)
         tgt = self.decoder_pos(tgt)
@@ -339,9 +341,11 @@ class Transformer(nn.Module):
             device
         )
 
-        # TODO: Use eos_id as ignore_id.
-        # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
-        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad)
+        tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
+        # TODO: Use length information to create the decoder padding mask
+        # We set the first column to False since the first column in ys_in_pad
+        # contains sos_id, which is the same as eos_id in our current setting.
+        tgt_key_padding_mask[:, 0] = False
 
         tgt = self.decoder_embed(ys_in_pad)  # (B, T) -> (B, T, F)
         tgt = self.decoder_pos(tgt)
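Note: the two hunks above build the target key-padding mask from eos_id (the value used to pad ys_in_pad) and then force column 0 to False, because the leading sos_id equals eos_id in this setup and would otherwise be masked out. A hypothetical stand-in for decoder_padding_mask illustrating the idea; the helper name, token ids, and shapes below are assumptions, not the recipe's actual code.

import torch

def decoder_padding_mask_sketch(ys_in_pad: torch.Tensor, ignore_id: int) -> torch.Tensor:
    # True marks positions the decoder attention should ignore (padding).
    return ys_in_pad == ignore_id

eos_id = 1                                    # assume sos_id == eos_id == 1
ys_in_pad = torch.tensor([[1, 7, 9, 1, 1],    # [sos, w, w, pad, pad]
                          [1, 5, 1, 1, 1]])   # [sos, w, pad, pad, pad]

tgt_key_padding_mask = decoder_padding_mask_sketch(ys_in_pad, ignore_id=eos_id)
# Column 0 holds sos_id, which equals eos_id here, so it must not be masked.
tgt_key_padding_mask[:, 0] = False
print(tgt_key_padding_mask)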
@@ -171,6 +171,8 @@ class AsrDataModule(DataModule):
                 max_duration=self.args.max_duration,
                 shuffle=True,
                 num_buckets=self.args.num_buckets,
+                bucket_method='equal_duration',
+                drop_last=True,
             )
         else:
             logging.info("Using SingleCutSampler.")
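Note: a hedged sketch of how the sampler above might be constructed with Lhotse's BucketingSampler of that era; the cuts path and numeric values are placeholders, and newer Lhotse releases may expose a different sampler API.

from lhotse import CutSet
from lhotse.dataset import BucketingSampler

cuts_train = CutSet.from_json("data/fbank/cuts_train.json.gz")  # placeholder path

train_sampler = BucketingSampler(
    cuts_train,
    max_duration=200.0,              # seconds of audio per batch (placeholder)
    shuffle=True,
    num_buckets=30,                  # placeholder; the recipe takes this from args
    bucket_method="equal_duration",  # each bucket holds roughly equal total duration
    drop_last=True,                  # drop the ragged last batch of each bucket
)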
@@ -184,8 +186,8 @@ class AsrDataModule(DataModule):
             train,
             sampler=train_sampler,
             batch_size=None,
-            num_workers=4,
-            persistent_workers=True,
+            num_workers=2,
+            persistent_workers=False,
         )
         return train_dl
 
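Note: for reference, a sketch of how these loader settings plug into torch.utils.data.DataLoader. batch_size=None is used because the Lhotse sampler already yields whole batches; persistent_workers=True keeps worker processes alive across epochs, which this commit turns off while also halving num_workers. The toy dataset and sampler below are placeholders, not the recipe's classes.

from torch.utils.data import DataLoader, Dataset

class ToyBatchDataset(Dataset):
    # Stand-in for the recipe's map-style dataset: the sampler hands __getitem__
    # a list of cut ids and it returns one ready-made batch (hence batch_size=None).
    def __getitem__(self, cut_ids):
        return {"supervisions": {"text": [f"utt-{i}" for i in cut_ids]}}

if __name__ == "__main__":
    toy_sampler = [[0, 1, 2], [3, 4]]  # stand-in sampler: each element is one batch

    train_dl = DataLoader(
        ToyBatchDataset(),
        sampler=toy_sampler,
        batch_size=None,           # batching is done by the sampler, not the loader
        num_workers=2,             # lowered from 4 in the diff above
        persistent_workers=False,  # tear workers down after each epoch
    )
    for batch in train_dl:
        print(batch["supervisions"]["text"])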
@@ -214,7 +216,7 @@ class AsrDataModule(DataModule):
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=True,
+            persistent_workers=False,
         )
         return valid_dl
 
@@ -610,7 +610,7 @@ def rescore_with_attention_decoder(
     # Since k2.ragged.unique_sequences will reorder paths within a seq,
     # `new2old` is a 1-D torch.Tensor mapping from the output path index
     # to the input path index.
-    # new2old.numel() == unique_word_seqs.tot_size(1)
+    # new2old.numel() == unique_word_seq.tot_size(1)
     unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
         word_seq, need_num_repeats=True, need_new2old_indexes=True
     )