Minor fixes to support DDP training.

Fangjun Kuang 2021-07-31 15:26:57 +08:00
parent b94d97da37
commit 398ed80d7a
3 changed files with 99 additions and 21 deletions


@@ -85,10 +85,10 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
- "method": "attention-decoder",
+ "method": "1best",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
- "num_paths": 1000,
+ "num_paths": 100,
}
)
return params
@@ -192,7 +192,7 @@ def decode_one_batch(
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
- hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+ hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in [
@@ -234,7 +234,7 @@ def decode_one_batch(
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
- hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+ hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
@@ -374,6 +374,8 @@ def main():
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
+ # HLG = k2.ctc_topo(4999).to(device)
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
@@ -383,7 +385,7 @@ def main():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
- first_word_disambig_id = lexicon.words["#0"]
+ first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
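The changes in the first file above replace lexicon.words with lexicon.word_table, a symbol table that is indexed in both directions: by integer id when mapping decoded ids back to words, and by string when looking up the id of the disambiguation symbol "#0". A minimal sketch of such a bidirectional table, written from scratch here for illustration rather than taken from k2:

# Sketch of a bidirectional symbol table (illustration only, not k2's
# actual implementation): an int key returns the symbol, a str key the id.
class SymbolTable:
    def __init__(self, id2sym):
        self._id2sym = dict(id2sym)
        self._sym2id = {s: i for i, s in self._id2sym.items()}

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._id2sym[key]
        return self._sym2id[key]


word_table = SymbolTable({0: "<eps>", 1: "HELLO", 2: "WORLD", 3: "#0"})
hyps = [[1, 2]]
print([[word_table[i] for i in ids] for ids in hyps])  # [['HELLO', 'WORLD']]
print(word_table["#0"])  # 3, the id of the disambiguation symbol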


@@ -130,14 +130,14 @@ def get_params() -> AttributeDict:
"weight_decay": 0.0,
"subsampling_factor": 4,
"start_epoch": 0,
- "num_epochs": 10,
+ "num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
- "valid_interval": 1000,
+ "valid_interval": 3000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
@@ -312,16 +312,26 @@ def compute_loss(
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
- att_loss = model.decoder_forward(
- encoder_memory,
- memory_mask,
- token_ids=token_ids,
- sos_id=graph_compiler.sos_id,
- eos_id=graph_compiler.eos_id,
- )
+ if hasattr(model, "module"):
+ att_loss = model.module.decoder_forward(
+ encoder_memory,
+ memory_mask,
+ token_ids=token_ids,
+ sos_id=graph_compiler.sos_id,
+ eos_id=graph_compiler.eos_id,
+ )
+ else:
+ att_loss = model.decoder_forward(
+ encoder_memory,
+ memory_mask,
+ token_ids=token_ids,
+ sos_id=graph_compiler.sos_id,
+ eos_id=graph_compiler.eos_id,
+ )
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
+ att_loss = torch.tensor([0])
# train_frames and valid_frames are used for printing.
if is_training:
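The hasattr(model, "module") branch added to compute_loss is the usual way to reach a method defined on the wrapped network once it sits inside torch.nn.parallel.DistributedDataParallel, which exposes the original object as model.module. A small self-contained sketch of the same pattern (the helper name is invented for this example):

import torch
import torch.nn as nn


def inner_model(model: nn.Module) -> nn.Module:
    # DistributedDataParallel (and DataParallel) keep the user's network in
    # .module; custom methods such as decoder_forward live there, not on the
    # wrapper itself.
    return model.module if hasattr(model, "module") else model


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def decoder_forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


net = Toy()  # for a DDP-wrapped model, inner_model(net) would return net.module
print(inner_model(net).decoder_forward(torch.randn(2, 4)).shape)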
@@ -331,7 +341,7 @@ def compute_loss(
assert loss.requires_grad == is_training
- return loss
+ return loss, ctc_loss.detach(), att_loss.detach()
def compute_validation_loss(
@@ -347,9 +357,11 @@ def compute_validation_loss(
model.eval()
tot_loss = 0.0
+ tot_ctc_loss = 0.0
+ tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
- loss = compute_loss(
+ loss, ctc_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
@@ -357,19 +369,32 @@ def compute_validation_loss(
is_training=False,
)
assert loss.requires_grad is False
+ assert ctc_loss.requires_grad is False
+ assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
+ tot_ctc_loss += ctc_loss.detach().cpu().item()
+ tot_att_loss += att_loss.detach().cpu().item()
tot_frames += params.valid_frames
if world_size > 1:
- s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+ s = torch.tensor(
+ [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
+ device=loss.device,
+ )
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
- tot_frames = s[1]
+ tot_ctc_loss = s[1]
+ tot_att_loss = s[2]
+ tot_frames = s[3]
params.valid_loss = tot_loss / tot_frames
+ params.valid_ctc_loss = tot_ctc_loss / tot_frames
+ params.valid_att_loss = tot_att_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
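Packing tot_loss, tot_ctc_loss, tot_att_loss and tot_frames into a single tensor before dist.all_reduce, as above, sums all four counters across ranks with one collective call; the per-frame averages are then taken from the global sums. A self-contained sketch, assuming a one-process group so it runs standalone (the helper name is illustrative):

import os

import torch
import torch.distributed as dist


def sum_across_ranks(values, device):
    # One all_reduce sums every entry of the packed tensor over all ranks.
    s = torch.tensor(values, dtype=torch.float64, device=device)
    dist.all_reduce(s, op=dist.ReduceOp.SUM)
    return s.cpu().tolist()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("gloo", rank=0, world_size=1)
    tot_loss, tot_ctc, tot_att, tot_frames = sum_across_ranks(
        [10.0, 6.0, 4.0, 100.0], device=torch.device("cpu")
    )
    print(tot_loss / tot_frames, tot_ctc / tot_frames, tot_att / tot_frames)
    dist.destroy_process_group()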
@@ -413,12 +438,15 @@ def train_one_epoch(
model.train()
tot_loss = 0.0 # sum of losses over all batches
+ tot_ctc_loss = 0.0
+ tot_att_loss = 0.0
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
- loss = compute_loss(
+ loss, ctc_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
@@ -434,19 +462,63 @@ def train_one_epoch(
optimizer.step()
loss_cpu = loss.detach().cpu().item()
+ ctc_loss_cpu = ctc_loss.detach().cpu().item()
+ att_loss_cpu = att_loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
+ tot_ctc_loss += ctc_loss_cpu
+ tot_att_loss += att_loss_cpu
tot_avg_loss = tot_loss / tot_frames
+ tot_avg_ctc_loss = tot_ctc_loss / tot_frames
+ tot_avg_att_loss = tot_att_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
+ f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+ f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
+ f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/current_ctc_loss",
+ ctc_loss_cpu / params.train_frames,
+ params.batch_idx_train,
+ )
+ tb_writer.add_scalar(
+ "train/current_att_loss",
+ att_loss_cpu / params.train_frames,
+ params.batch_idx_train,
+ )
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
+ tb_writer.add_scalar(
+ "train/tot_avg_ctc_loss",
+ tot_avg_ctc_loss,
+ params.batch_idx_train,
+ )
+ tb_writer.add_scalar(
+ "train/tot_avg_att_loss",
+ tot_avg_att_loss,
+ params.batch_idx_train,
+ )
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
@@ -457,7 +529,10 @@ def train_one_epoch(
)
model.train()
logging.info(
- f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+ f"Epoch {params.cur_epoch}, "
+ f"valid ctc loss {params.valid_ctc_loss:.4f},"
+ f"valid att loss {params.valid_att_loss:.4f},"
+ f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
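For context on when the world_size > 1 paths above are exercised: the training script must be launched with one process per GPU, all joining the same process group. A minimal, generic bootstrap sketch in plain PyTorch (names invented for illustration; this is not code from the commit):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    # Each spawned process joins the same group; rank 0 usually does the logging.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12356")
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    # ... build the model, wrap it in nn.parallel.DistributedDataParallel,
    # run train_one_epoch / compute_validation_loss as in the diff ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = max(torch.cuda.device_count(), 1)
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)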


@@ -659,8 +659,9 @@ def rescore_with_attention_decoder(
0, path_to_seq_map_long
)
+ # TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
- expanded_memory, expanded_memory_key_padding_mask, token_ids
+ expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs
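The new decoder_nll call hard-codes both the sos and eos token ids as 1, and the TODO notes that they should eventually be passed in as arguments. A hedged sketch of that follow-up with a stub decoder and hypothetical names (not part of this commit):

from typing import List


class StubDecoder:
    # Stand-in for the attention decoder; it only echoes the ids it receives.
    def decoder_nll(self, memory, memory_key_padding_mask, token_ids, sos_id, eos_id):
        return {"sos_id": sos_id, "eos_id": eos_id, "num_seqs": len(token_ids)}


def rescore(model, memory, mask, token_ids: List[List[int]], sos_id: int, eos_id: int):
    # The ids are threaded through instead of being hard-coded at the call site.
    return model.decoder_nll(memory, mask, token_ids, sos_id, eos_id)


print(rescore(StubDecoder(), None, None, [[7, 8], [9]], sos_id=1, eos_id=1))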