From 187d59d59bbdb5044d12dc468ad593954f601231 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Mon, 11 Apr 2022 20:37:19 +0800
Subject: [PATCH] check some files

---
 .flake8                                       |  3 +-
 .../ASR/pruned_transducer_stateless2/train.py | 71 ++++++++++---------
 2 files changed, 39 insertions(+), 35 deletions(-)

diff --git a/.flake8 b/.flake8
index 7e80abd3e..5b3c444b8 100644
--- a/.flake8
+++ b/.flake8
@@ -7,8 +7,7 @@ per-file-ignores =
     egs/librispeech/ASR/*/conformer.py: E501,
     egs/aishell/ASR/*/conformer.py: E501,
     egs/tedlium3/ASR/*/conformer.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
+    egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
 
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index b9ea0def6..f3f37b1bc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -36,16 +36,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
-import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import optim
 import sentencepiece as spm
 import torch
-import optim  # from .
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -56,26 +55,23 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve, Eden
+from optim import Eden, Eve
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
+from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall import diagnostics
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    setup_logger,
-    str2bool,
-)
 
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -158,7 +154,7 @@
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this."""
+        We suggest not to change this.""",
     )
 
     parser.add_argument(
@@ -166,7 +162,7 @@
         type=float,
         default=6,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
- """ + """, ) parser.add_argument( @@ -318,7 +314,7 @@ def get_params() -> AttributeDict: # parameters for joiner "joiner_dim": 512, # parameters for Noam - "model_warm_step": 3000, # arg given to model, not for lrate + "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), } ) @@ -489,7 +485,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, - warmup: float = 1.0 + warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -536,18 +532,24 @@ def compute_loss( # for the same amount of time (model_warm_step), to avoid # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. - pruned_loss_scale = (0.0 if warmup < 1.0 else - (0.1 if warmup > 1.0 and warmup < 2.0 else - 1.0)) - loss = (params.simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss) + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -650,7 +652,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step) + warmup=(params.batch_idx_train / params.model_warm_step), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -665,8 +667,10 @@ def train_one_epoch( if params.print_diagnostics and batch_idx == 5: return - if (params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0): + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, @@ -695,7 +699,9 @@ def train_one_epoch( ) if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", params.batch_idx_train + ) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train @@ -784,18 +790,19 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - optimizer = Eve( - model.parameters(), - lr=params.initial_lr) + optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") optimizer.load_state_dict(checkpoints["optimizer"]) - if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None: + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): logging.info("Loading scheduler state dict") scheduler.load_state_dict(checkpoints["scheduler"]) @@ -805,7 +812,6 @@ def run(rank, world_size, args): ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) - librispeech = LibriSpeechAsrDataModule(args) train_cuts = librispeech.train_clean_100_cuts() @@ -855,7 +861,6 @@ def run(rank, world_size, args): 
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
-        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
@@ -919,7 +924,7 @@
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup = 0.0
+                warmup=0.0,
             )
             loss.backward()
             optimizer.step()
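
Note (illustration, not part of the patch): the block reformatted inside compute_loss()
above combines the simple and pruned transducer losses with a warmup-dependent scale,
where warmup is params.batch_idx_train / params.model_warm_step. Below is a minimal,
self-contained Python sketch of that scaling; the function name and the example values
(including 0.5 for simple_loss_scale) are assumptions for illustration, not code taken
from this commit.

def combine_transducer_losses(
    simple_loss: float,
    pruned_loss: float,
    simple_loss_scale: float,
    warmup: float,
) -> float:
    # warmup < 1.0: pruned loss disabled until model_warm_step batches are seen.
    # 1.0 < warmup < 2.0: pruned loss damped to 0.1 while the alignment stabilizes.
    # otherwise: both losses contribute at full scale.
    pruned_loss_scale = (
        0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    )
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

# Example: early in training only the simple loss counts.
print(combine_transducer_losses(10.0, 30.0, 0.5, warmup=0.5))  # 5.0
print(combine_transducer_losses(10.0, 30.0, 0.5, warmup=1.5))  # 8.0
print(combine_transducer_losses(10.0, 30.0, 0.5, warmup=2.5))  # 35.0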