mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Minor fixes to support DDP training.
This commit is contained in:
parent
b94d97da37
commit
398ed80d7a
@ -85,10 +85,10 @@ def get_params() -> AttributeDict:
|
||||
# - whole-lattice-rescoring
|
||||
# - attention-decoder
|
||||
# "method": "whole-lattice-rescoring",
|
||||
"method": "attention-decoder",
|
||||
"method": "1best",
|
||||
# num_paths is used when method is "nbest", "nbest-rescoring",
|
||||
# and attention-decoder
|
||||
"num_paths": 1000,
|
||||
"num_paths": 100,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -192,7 +192,7 @@ def decode_one_batch(
|
||||
key = f"no_rescore-{params.num_paths}"
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
|
||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in [
|
||||
@ -234,7 +234,7 @@ def decode_one_batch(
|
||||
ans = dict()
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
|
||||
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
return ans
|
||||
|
||||
@ -374,6 +374,8 @@ def main():
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
# HLG = k2.ctc_topo(4999).to(device)
|
||||
|
||||
if params.method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
@ -383,7 +385,7 @@ def main():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.words["#0"]
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
|
@ -130,14 +130,14 @@ def get_params() -> AttributeDict:
|
||||
"weight_decay": 0.0,
|
||||
"subsampling_factor": 4,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 10,
|
||||
"num_epochs": 50,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 10,
|
||||
"valid_interval": 1000,
|
||||
"valid_interval": 3000,
|
||||
"beam_size": 10,
|
||||
"reduction": "sum",
|
||||
"use_double_scores": True,
|
||||
@ -312,16 +312,26 @@ def compute_loss(
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
with torch.set_grad_enabled(is_training):
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = ctc_loss
|
||||
att_loss = torch.tensor([0])
|
||||
|
||||
# train_frames and valid_frames are used for printing.
|
||||
if is_training:
|
||||
@ -331,7 +341,7 @@ def compute_loss(
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
return loss
|
||||
return loss, ctc_loss.detach(), att_loss.detach()
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
@ -347,9 +357,11 @@ def compute_validation_loss(
|
||||
model.eval()
|
||||
|
||||
tot_loss = 0.0
|
||||
tot_ctc_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss = compute_loss(
|
||||
loss, ctc_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
@ -357,19 +369,32 @@ def compute_validation_loss(
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
assert ctc_loss.requires_grad is False
|
||||
assert att_loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
tot_loss += loss_cpu
|
||||
|
||||
tot_ctc_loss += ctc_loss.detach().cpu().item()
|
||||
tot_att_loss += att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.valid_frames
|
||||
|
||||
if world_size > 1:
|
||||
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
|
||||
s = torch.tensor(
|
||||
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
|
||||
device=loss.device,
|
||||
)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
s = s.cpu().tolist()
|
||||
tot_loss = s[0]
|
||||
tot_frames = s[1]
|
||||
tot_ctc_loss = s[1]
|
||||
tot_att_loss = s[2]
|
||||
tot_frames = s[3]
|
||||
|
||||
params.valid_loss = tot_loss / tot_frames
|
||||
params.valid_ctc_loss = tot_ctc_loss / tot_frames
|
||||
params.valid_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if params.valid_loss < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
@ -413,12 +438,15 @@ def train_one_epoch(
|
||||
model.train()
|
||||
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_ctc_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss = compute_loss(
|
||||
loss, ctc_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
@ -434,19 +462,63 @@ def train_one_epoch(
|
||||
optimizer.step()
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
ctc_loss_cpu = ctc_loss.detach().cpu().item()
|
||||
att_loss_cpu = att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.train_frames
|
||||
tot_loss += loss_cpu
|
||||
tot_ctc_loss += ctc_loss_cpu
|
||||
tot_att_loss += att_loss_cpu
|
||||
|
||||
tot_avg_loss = tot_loss / tot_frames
|
||||
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
||||
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/current_ctc_loss",
|
||||
ctc_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_att_loss",
|
||||
att_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_loss",
|
||||
loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_ctc_loss",
|
||||
tot_avg_ctc_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_att_loss",
|
||||
tot_avg_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_loss",
|
||||
tot_avg_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
compute_validation_loss(
|
||||
params=params,
|
||||
@ -457,7 +529,10 @@ def train_one_epoch(
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"valid ctc loss {params.valid_ctc_loss:.4f},"
|
||||
f"valid att loss {params.valid_att_loss:.4f},"
|
||||
f"valid loss {params.valid_loss:.4f},"
|
||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
|
@ -659,8 +659,9 @@ def rescore_with_attention_decoder(
|
||||
0, path_to_seq_map_long
|
||||
)
|
||||
|
||||
# TODO: pass the sos_token_id and eos_token_id via function arguments
|
||||
nll = model.decoder_nll(
|
||||
expanded_memory, expanded_memory_key_padding_mask, token_ids
|
||||
expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == num_word_seqs
|
||||
|
Loading…
x
Reference in New Issue
Block a user