check some files

luomingshuang 2022-04-11 20:37:19 +08:00
parent 05fd40ba68
commit 187d59d59b
2 changed files with 39 additions and 35 deletions

Changed file 1 of 2: the project's flake8 configuration (per-file-ignores).

@@ -7,8 +7,7 @@ per-file-ignores =
     egs/librispeech/ASR/*/conformer.py: E501,
     egs/aishell/ASR/*/conformer.py: E501,
     egs/tedlium3/ASR/*/conformer.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
-    egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
+    egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
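The new wildcard entry works because flake8 matches per-file-ignores patterns fnmatch-style, so one glob now covers scaling.py, model.py, and any other module in that directory. A minimal sketch of that matching, using only the standard-library fnmatch module and the two paths removed in this hunk:

from fnmatch import fnmatch

pattern = "egs/librispeech/ASR/pruned_transducer_stateless2/*.py"
files = [
    "egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py",
    "egs/librispeech/ASR/pruned_transducer_stateless2/model.py",
]

for path in files:
    # Both paths match, so the two explicit entries become redundant.
    print(path, fnmatch(path, pattern))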

Changed file 2 of 2: the recipe's training script.

@@ -36,16 +36,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
-import math
 import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

 import k2
+import optim
 import sentencepiece as spm
 import torch
-import optim  # from .
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -56,26 +55,23 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
-from optim import Eve, Eden
+from optim import Eden, Eve
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

+from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import save_checkpoint_with_global_batch_idx
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall import diagnostics
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]


 def get_parser():
     parser = argparse.ArgumentParser(
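The Union alias is only reflowed here, but the reason it exists is that PyTorch's built-in schedulers and the recipe's own optim.LRScheduler expose the same step()-style interface. A small sketch of that idea, with an illustrative stand-in class rather than icefall's real optim module:

from typing import Union

import torch


class DummyLRScheduler:
    # Illustrative stand-in for the recipe's optim.LRScheduler.
    def step(self) -> None:
        pass


# Same shape as the alias in the diff above.
LRSchedulerType = Union[
    torch.optim.lr_scheduler._LRScheduler, DummyLRScheduler
]


def step_scheduler(scheduler: LRSchedulerType) -> None:
    # Either scheduler family exposes step(), so callers stay agnostic.
    scheduler.step()


step_scheduler(DummyLRScheduler())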
@@ -158,7 +154,7 @@ def get_parser():
         type=float,
         default=5000,
         help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this."""
+        We suggest not to change this.""",
     )

     parser.add_argument(
@@ -166,7 +162,7 @@ def get_parser():
         type=float,
         default=6,
         help="""Number of epochs that affects how rapidly the learning rate decreases.
-        """
+        """,
     )

     parser.add_argument(
@@ -318,7 +314,7 @@ def get_params() -> AttributeDict:
             # parameters for joiner
             "joiner_dim": 512,
             # parameters for Noam
-            "model_warm_step": 3000, # arg given to model, not for lrate
+            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
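params is an AttributeDict, so entries such as model_warm_step are later read as attributes throughout the script. A rough stand-in (not icefall's implementation) showing the behaviour the code relies on:

class AttrDict(dict):
    # Rough stand-in for icefall.utils.AttributeDict: keys double as attributes.
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e


params = AttrDict({"joiner_dim": 512, "model_warm_step": 3000})
print(params.joiner_dim, params["model_warm_step"])  # -> 512 3000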
@@ -489,7 +485,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0
+    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -536,18 +532,24 @@ def compute_loss(
         # for the same amount of time (model_warm_step), to avoid
         # overwhelming the simple_loss and causing it to diverge,
         # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (0.0 if warmup < 1.0 else
-                             (0.1 if warmup > 1.0 and warmup < 2.0 else
-                              1.0))
-        loss = (params.simple_loss_scale * simple_loss +
-                pruned_loss_scale * pruned_loss)
+        pruned_loss_scale = (
+            0.0
+            if warmup < 1.0
+            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        )
+        loss = (
+            params.simple_loss_scale * simple_loss
+            + pruned_loss_scale * pruned_loss
+        )

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -650,7 +652,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step)
+                warmup=(params.batch_idx_train / params.model_warm_step),
             )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
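The last context line keeps a decayed running total of the loss metrics, so older batches gradually fade out of the reported numbers. A sketch of the same update rule with plain floats, assuming an illustrative reset_interval of 200:

reset_interval = 200  # assumed value, used only for this illustration

tot_loss = 0.0
for batch_loss in [1.0, 0.9, 0.8, 0.7]:
    # Each step decays the running total slightly before adding the new
    # batch, so very old batches stop influencing the reported loss.
    tot_loss = tot_loss * (1 - 1 / reset_interval) + batch_loss

print(tot_loss)  # a little below the plain sum of 3.4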
@@ -665,8 +667,10 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return

-        if (params.batch_idx_train > 0
-                and params.batch_idx_train % params.save_every_n == 0):
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
             params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
@@ -695,7 +699,7 @@ def train_one_epoch(
             )

             if tb_writer is not None:
-                tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )

                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -784,18 +790,19 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank])
     model.device = device

-    optimizer = Eve(
-        model.parameters(),
-        lr=params.initial_lr)
+    optimizer = Eve(model.parameters(), lr=params.initial_lr)

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])

-    if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
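The reformatted condition is the usual guard before restoring optimizer and scheduler state from a checkpoint dict. A minimal sketch of the pattern with stock PyTorch classes (SGD and StepLR stand in for the recipe's Eve and Eden):

import torch


def restore_training_state(checkpoints, optimizer, scheduler):
    # Only restore state that the checkpoint actually carries;
    # a fresh run passes checkpoints=None and skips both branches.
    if checkpoints and "optimizer" in checkpoints:
        optimizer.load_state_dict(checkpoints["optimizer"])
    if (
        checkpoints
        and "scheduler" in checkpoints
        and checkpoints["scheduler"] is not None
    ):
        scheduler.load_state_dict(checkpoints["scheduler"])


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# A hypothetical checkpoint dict shaped like the one train.py loads.
checkpoints = {"optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()}
restore_training_state(checkpoints, optimizer, scheduler)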
@@ -805,7 +812,6 @@ def run(rank, world_size, args):
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

-
     librispeech = LibriSpeechAsrDataModule(args)

     train_cuts = librispeech.train_clean_100_cuts()
@@ -855,7 +861,6 @@ def run(rank, world_size, args):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)

-        cur_lr = scheduler.get_last_lr()[0]
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@@ -919,7 +924,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup = 0.0
+                    warmup=0.0,
                 )
             loss.backward()
             optimizer.step()