diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 34b99cd2d..98bd47bc1 100644
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -59,10 +59,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -80,10 +77,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=35,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=35, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -230,10 +224,7 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -335,9 +326,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(token_ids)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )
 
     ctc_loss = k2.ctc_loss(
@@ -374,12 +363,12 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["ctc_loss"] = ctc_loss.detach().cpu().item()
     if params.att_rate != 0.0:
-        info['att_loss'] = att_loss.detach().cpu().item()
+        info["att_loss"] = att_loss.detach().cpu().item()
 
-    info['loss'] = loss.detach().cpu().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -410,7 +399,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
@@ -489,15 +478,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -509,17 +492,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
            model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation: {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train
+                    tb_writer, "train/valid_", params.batch_idx_train
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@@ -563,10 +542,7 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
 
     graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        params.lang_dir, device=device, sos_token="<sos/eos>", eos_token="<sos/eos>",
     )
 
     logging.info("About to create model")
@@ -607,9 +583,7 @@ def run(rank, world_size, args):
 
         cur_lr = optimizer._rate
         if tb_writer is not None:
-            tb_writer.add_scalar(
-                "train/learning_rate", cur_lr, params.batch_idx_train
-            )
+            tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
         if rank == 0:
@@ -629,10 +603,7 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, rank=rank,
         )
 
     logging.info("Done!")
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index 016d51e2c..2b22e4e0f 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -58,10 +58,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -79,10 +76,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=20,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=20, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -209,10 +203,7 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -312,9 +303,7 @@ def compute_loss(
     decoding_graph = graph_compiler.compile(texts)
 
     dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
+        nnet_output, supervision_segments, allow_truncate=params.subsampling_factor - 1,
     )
 
     loss = k2.ctc_loss(
@@ -328,8 +317,8 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['loss'] = loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -363,7 +352,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
 
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -439,15 +428,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -458,17 +441,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train,
+                    tb_writer, "train/valid_", params.batch_idx_train,
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
 
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -526,9 +505,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])
 
     optimizer = optim.AdamW(
-        model.parameters(),
-        lr=params.lr,
-        weight_decay=params.weight_decay,
+        model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
 
@@ -548,9 +525,7 @@ def run(rank, world_size, args):
 
         if tb_writer is not None:
             tb_writer.add_scalar(
-                "train/lr",
-                scheduler.get_last_lr()[0],
-                params.batch_idx_train,
+                "train/lr", scheduler.get_last_lr()[0], params.batch_idx_train,
             )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py
index 582f3e822..f8e8538ca 100644
--- a/egs/yesno/ASR/tdnn/train.py
+++ b/egs/yesno/ASR/tdnn/train.py
@@ -33,10 +33,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -54,10 +51,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=15,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=15, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -187,10 +181,7 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
    )
 
     keys = [
@@ -287,16 +278,12 @@ def compute_loss(
 
     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
-        dtype=torch.int32,
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], dtype=torch.int32,
     )
 
     decoding_graph = graph_compiler.compile(texts)
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-    )
+    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments,)
 
     loss = k2.ctc_loss(
         decoding_graph=decoding_graph,
@@ -309,8 +296,8 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = LossRecord()
-    info['frames'] = supervision_segments[:, 2].sum().item()
-    info['loss'] = loss.detach().cpu().item()
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    info["loss"] = loss.detach().cpu().item()
 
     return loss, info
 
@@ -344,7 +331,7 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
 
     if loss_value < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -420,15 +407,9 @@ def train_one_epoch(
 
             if tb_writer is not None:
                 loss_info.write_summary(
-                    tb_writer,
-                    "train/current_",
-                    params.batch_idx_train
-                )
-                tot_loss.write_summary(
-                    tb_writer,
-                    "train/tot_",
-                    params.batch_idx_train
+                    tb_writer, "train/current_", params.batch_idx_train
                 )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             valid_info = compute_validation_loss(
@@ -439,17 +420,13 @@ def train_one_epoch(
                 world_size=world_size,
             )
             model.train()
-            logging.info(
-                f"Epoch {params.cur_epoch}, validation {valid_info}"
-            )
+            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
             if tb_writer is not None:
                 valid_info.write_summary(
-                    tb_writer,
-                    "train/valid_",
-                    params.batch_idx_train,
+                    tb_writer, "train/valid_", params.batch_idx_train,
                 )
 
-    loss_value = tot_loss['loss'] / tot_loss['frames']
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
 
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -506,9 +483,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])
 
     optimizer = optim.SGD(
-        model.parameters(),
-        lr=params.lr,
-        weight_decay=params.weight_decay,
+        model.parameters(), lr=params.lr, weight_decay=params.weight_decay,
     )
 
     if checkpoints:
@@ -542,11 +517,7 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params,
-            model=model,
-            optimizer=optimizer,
-            scheduler=None,
-            rank=rank,
+            params=params, model=model, optimizer=optimizer, scheduler=None, rank=rank,
         )
 
     logging.info("Done!")
diff --git a/icefall/utils.py b/icefall/utils.py
index 876c926d9..2c551d884 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -107,9 +107,7 @@ def setup_logger(
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
-        formatter = (
-            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-        )
+        formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
         log_filename = f"{log_filename}-{date_time}"
 
     os.makedirs(os.path.dirname(log_filename), exist_ok=True)
@@ -236,9 +234,7 @@ def get_texts(
         return aux_labels.tolist()
 
 
-def store_transcripts(
-    filename: Pathlike, texts: Iterable[Tuple[str, str]]
-) -> None:
+def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str]]) -> None:
     """Save predicted results and reference transcripts to a file.
 
     Args:
@@ -369,19 +365,14 @@ def write_error_stats(
         ]
         ali = list(filter(lambda x: x != [[], []], ali))
         ali = [
-            [
-                ERR if x == [] else " ".join(x),
-                ERR if y == [] else " ".join(y),
-            ]
+            [ERR if x == [] else " ".join(x), ERR if y == [] else " ".join(y),]
             for x, y in ali
         ]
 
         print(
             " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -391,9 +382,7 @@ def write_error_stats(
 
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)
 
     print("", file=f)
@@ -407,9 +396,7 @@ def write_error_stats(
         print(f"{count} {hyp}", file=f)
 
     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -428,7 +415,7 @@ class LossRecord(collections.defaultdict):
         # makes undefined items default to int() which is zero.
         super(LossRecord, self).__init__(int)
 
-    def __add__(self, other: 'LossRecord') -> 'LossRecord':
+    def __add__(self, other: "LossRecord") -> "LossRecord":
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v
@@ -436,19 +423,19 @@ class LossRecord(collections.defaultdict):
             ans[k] = ans[k] + v
         return ans
 
-    def __mul__(self, alpha: float) -> 'LossRecord':
+    def __mul__(self, alpha: float) -> "LossRecord":
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v * alpha
         return ans
 
     def __str__(self) -> str:
-        ans = ''
+        ans = ""
         for k, v in self.norm_items():
-            norm_value = '%.4g' % v
-            ans += (str(k) + '=' + str(norm_value) + ', ')
-        frames = str(self['frames'])
-        ans += 'over ' + frames + ' frames.'
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        frames = str(self["frames"])
+        ans += "over " + frames + " frames."
         return ans
 
     def norm_items(self) -> List[Tuple[str, float]]:
@@ -456,10 +443,10 @@ class LossRecord(collections.defaultdict):
         Returns a list of pairs, like:
           [('ctc_loss', 0.1), ('att_loss', 0.07)]
         """
-        num_frames = self['frames'] if 'frames' in self else 1
+        num_frames = self["frames"] if "frames" in self else 1
         ans = []
         for k, v in self.items():
-            if k != 'frames':
+            if k != "frames":
                 norm_value = float(v) / num_frames
                 ans.append((k, norm_value))
         return ans
@@ -470,17 +457,13 @@ class LossRecord(collections.defaultdict):
         all processes get the total.
         """
         keys = sorted(self.keys())
-        s = torch.tensor([float(self[k]) for k in keys],
-                         device=device)
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
         dist.all_reduce(s, op=dist.ReduceOp.SUM)
         for k, v in zip(keys, s.cpu().tolist()):
             self[k] = v
 
     def write_summary(
-        self,
-        tb_writer: SummaryWriter,
-        prefix: str,
-        batch_idx: int,
+        self, tb_writer: SummaryWriter, prefix: str, batch_idx: int,
     ) -> None:
         """Add logging information to a TensorBoard writer.