Mirror of https://github.com/k2-fsa/icefall.git

Commit 398ed80d7a (parent b94d97da37)

Minor fixes to support DDP training.
@@ -85,10 +85,10 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
+            "method": "1best",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
-            "num_paths": 1000,
+            "num_paths": 100,
         }
     )
     return params
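get_params() returns an AttributeDict, so the keys changed above ("method", "num_paths") are later read as attributes. A small sketch of that access pattern, using the new defaults; AttributeDict comes from icefall.utils:

from icefall.utils import AttributeDict

params = AttributeDict(
    {
        "method": "1best",
        "num_paths": 100,
    }
)

assert params.method == "1best"    # attribute-style access
assert params["num_paths"] == 100  # still behaves like a plain dict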
@@ -192,7 +192,7 @@ def decode_one_batch(
         key = f"no_rescore-{params.num_paths}"

         hyps = get_texts(best_path)
-        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}

     assert params.method in [
@@ -234,7 +234,7 @@ def decode_one_batch(
     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
         hyps = get_texts(best_path)
-        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
     return ans

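The lexicon.words -> lexicon.word_table rename in these two hunks maps recognized word ids back to word strings through the lexicon's symbol table (a k2.SymbolTable). A hedged sketch of that list comprehension in isolation; the toy symbols and ids below are made up for illustration, in the recipe the table comes from lexicon.word_table:

import k2

word_table = k2.SymbolTable.from_str("<eps> 0\nHELLO 1\nWORLD 2")

hyps = [[1, 2], [2]]  # word ids, as returned by get_texts(best_path)
texts = [[word_table[i] for i in ids] for ids in hyps]
assert texts == [["HELLO", "WORLD"], ["WORLD"]]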
@@ -374,6 +374,8 @@ def main():
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

+    # HLG = k2.ctc_topo(4999).to(device)
+
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
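The commented-out `# HLG = k2.ctc_topo(4999).to(device)` line above points at decoding with a bare CTC topology instead of the HLG graph. A minimal sketch of that idea, assuming (not stated in the hunk) that 4999 is the largest BPE token id of this recipe:

import k2
import torch

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Build a CTC topology covering token ids 0..4999 and move it to the device.
# It can stand in for HLG when decoding without a lexicon/LM.
ctc_topo = k2.ctc_topo(4999).to(device)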
@@ -383,7 +385,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]

             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
             # G.aux_labels is not needed in later computations, so
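For context on why "#0" is looked up here: back-off arcs in the n-gram LM carry the disambiguation symbol #0, and their labels have to be mapped to epsilon (0) before G can be composed with the decoding lattices. A hedged sketch of that step; the directory names are illustrative and the relabelling line is an assumption about the surrounding code, not part of this hunk:

from pathlib import Path

import k2
from icefall.lexicon import Lexicon

lang_dir = Path("data/lang_bpe")  # illustrative paths
lm_dir = Path("data/lm")

lexicon = Lexicon(lang_dir)
first_word_disambig_id = lexicon.word_table["#0"]

with open(lm_dir / "G_4_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Arcs entering the back-off state are labelled #0; relabel them to 0 (epsilon)
# so that G composes correctly with the lattices.
G.labels[G.labels >= first_word_disambig_id] = 0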
@@ -130,14 +130,14 @@ def get_params() -> AttributeDict:
             "weight_decay": 0.0,
             "subsampling_factor": 4,
             "start_epoch": 0,
-            "num_epochs": 10,
+            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
-            "valid_interval": 1000,
+            "valid_interval": 3000,
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
@@ -312,16 +312,26 @@ def compute_loss(

     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            att_loss = model.decoder_forward(
-                encoder_memory,
-                memory_mask,
-                token_ids=token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
-            )
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
+        att_loss = torch.tensor([0])

     # train_frames and valid_frames are used for printing.
     if is_training:
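The hasattr(model, "module") branch above is the core DDP fix in compute_loss: torch.nn.parallel.DistributedDataParallel only proxies forward(), so a custom method such as decoder_forward has to be reached through the wrapped module. A minimal sketch of the dispatch pattern with a toy module; the names below are illustrative, not from icefall:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stands in for the conformer model; decoder_forward is a custom method."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decoder_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


def run_decoder_forward(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # DistributedDataParallel stores the user's model as `model.module`;
    # a bare (non-DDP) model has no such attribute, so fall back to itself.
    inner = model.module if hasattr(model, "module") else model
    return inner.decoder_forward(x)


x = torch.zeros(1)
print(run_decoder_forward(ToyModel(), x))  # works the same whether or not
# the model has been wrapped in DDP.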
@@ -331,7 +341,7 @@ def compute_loss(

     assert loss.requires_grad == is_training

-    return loss
+    return loss, ctc_loss.detach(), att_loss.detach()


 def compute_validation_loss(
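compute_loss now also returns the detached CTC and attention losses for logging; only loss itself is backpropagated. A small illustration of what detach() buys (toy tensors, not icefall code):

import torch

x = torch.tensor([2.0], requires_grad=True)
ctc_like = (x * 3.0).sum()
logged = ctc_like.detach()

assert ctc_like.requires_grad is True   # still part of the autograd graph
assert logged.requires_grad is False    # safe to stash for logging / all_reduce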
@@ -347,9 +357,11 @@ def compute_validation_loss(
     model.eval()

     tot_loss = 0.0
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
     tot_frames = 0.0
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, ctc_loss, att_loss = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -357,19 +369,32 @@ def compute_validation_loss(
             is_training=False,
         )
         assert loss.requires_grad is False
+        assert ctc_loss.requires_grad is False
+        assert att_loss.requires_grad is False

         loss_cpu = loss.detach().cpu().item()
         tot_loss += loss_cpu
+
+        tot_ctc_loss += ctc_loss.detach().cpu().item()
+        tot_att_loss += att_loss.detach().cpu().item()
+
         tot_frames += params.valid_frames

     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+        s = torch.tensor(
+            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
+            device=loss.device,
+        )
         dist.all_reduce(s, op=dist.ReduceOp.SUM)
         s = s.cpu().tolist()
         tot_loss = s[0]
-        tot_frames = s[1]
+        tot_ctc_loss = s[1]
+        tot_att_loss = s[2]
+        tot_frames = s[3]

     params.valid_loss = tot_loss / tot_frames
+    params.valid_ctc_loss = tot_ctc_loss / tot_frames
+    params.valid_att_loss = tot_att_loss / tot_frames

     if params.valid_loss < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
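When world_size > 1, the per-rank sums are combined with a single all_reduce over one packed tensor, as the hunk above does for tot_loss, tot_ctc_loss, tot_att_loss and tot_frames. A sketch of the pattern in isolation; it assumes torch.distributed has already been initialized (e.g. via init_process_group):

import torch
import torch.distributed as dist


def sum_across_ranks(values: list, device: torch.device) -> list:
    """Pack running sums into one tensor, SUM-reduce across ranks, unpack."""
    t = torch.tensor(values, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t.cpu().tolist()


# Usage mirroring compute_validation_loss:
# tot_loss, tot_ctc_loss, tot_att_loss, tot_frames = sum_across_ranks(
#     [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], device=loss.device
# )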
@@ -413,12 +438,15 @@ def train_one_epoch(
     model.train()

     tot_loss = 0.0  # sum of losses over all batches
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
+
     tot_frames = 0.0  # sum of frames over all batches
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss = compute_loss(
+        loss, ctc_loss, att_loss = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -434,19 +462,63 @@ def train_one_epoch(
         optimizer.step()

         loss_cpu = loss.detach().cpu().item()
+        ctc_loss_cpu = ctc_loss.detach().cpu().item()
+        att_loss_cpu = att_loss.detach().cpu().item()

         tot_frames += params.train_frames
         tot_loss += loss_cpu
+        tot_ctc_loss += ctc_loss_cpu
+        tot_att_loss += att_loss_cpu
+
         tot_avg_loss = tot_loss / tot_frames
+        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
+        tot_avg_att_loss = tot_att_loss / tot_frames

         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
+                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
                 f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
+                f"total avg att loss: {tot_avg_att_loss:.4f}, "
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
             )

+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_ctc_loss",
+                    ctc_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_att_loss",
+                    att_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_ctc_loss",
+                    tot_avg_ctc_loss,
+                    params.batch_idx_train,
+                )
+
+                tb_writer.add_scalar(
+                    "train/tot_avg_att_loss",
+                    tot_avg_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
                 params=params,
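tb_writer above is expected to be a torch.utils.tensorboard.SummaryWriter, or None when TensorBoard logging is disabled. A minimal, self-contained sketch of the add_scalar calls the hunk adds; the log directory and values are illustrative:

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="exp/tensorboard")  # illustrative directory

global_step = 100  # stands in for params.batch_idx_train
tb_writer.add_scalar("train/current_ctc_loss", 0.25, global_step)
tb_writer.add_scalar("train/tot_avg_loss", 0.31, global_step)
tb_writer.close()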
@@ -457,7 +529,10 @@ def train_one_epoch(
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+                f"Epoch {params.cur_epoch}, "
+                f"valid ctc loss {params.valid_ctc_loss:.4f},"
+                f"valid att loss {params.valid_att_loss:.4f},"
+                f"valid loss {params.valid_loss:.4f},"
                 f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
@@ -659,8 +659,9 @@ def rescore_with_attention_decoder(
         0, path_to_seq_map_long
     )

+    # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids
+        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
     )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
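The hard-coded `1, 1` are the sos and eos token ids of the recipe's BPE model, and the TODO notes they should eventually arrive as function arguments. A hedged sketch of how a caller might look them up instead of hard-coding them; it assumes a sentencepiece model with a "<sos/eos>" piece, as used in icefall's BPE recipes, and the model path is illustrative:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe/bpe.model")  # illustrative path

sos_id = sp.piece_to_id("<sos/eos>")  # a single shared sos/eos piece
eos_id = sp.piece_to_id("<sos/eos>")

# ...and pass them on instead of the literal 1, 1:
# nll = model.decoder_nll(
#     expanded_memory, expanded_memory_key_padding_mask, token_ids, sos_id, eos_id
# )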
|
Loading…
x
Reference in New Issue
Block a user