check some files

luomingshuang 2022-04-11 20:37:19 +08:00
parent 05fd40ba68
commit 187d59d59b
2 changed files with 39 additions and 35 deletions

View File

@@ -7,8 +7,7 @@ per-file-ignores =
egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501,
egs/tedlium3/ASR/*/conformer.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
# invalid escape sequence (caused by TeX formula), W605
icefall/utils.py: E501, W605

View File

@@ -36,16 +36,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import logging
import math
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import optim # from .
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
@@ -56,26 +55,23 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eve, Eden
from optim import Eden, Eve
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall import diagnostics
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def get_parser():
parser = argparse.ArgumentParser(
@@ -158,7 +154,7 @@ def get_parser():
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this."""
We suggest not to change this.""",
)
parser.add_argument(
@@ -166,7 +162,7 @@ def get_parser():
type=float,
default=6,
help="""Number of epochs that affects how rapidly the learning rate decreases.
"""
""",
)
parser.add_argument(
@@ -318,7 +314,7 @@ def get_params() -> AttributeDict:
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
}
)
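For orientation, the params container built by get_params() is an AttributeDict from icefall.utils, i.e. a dict whose entries are also readable as attributes (params.model_warm_step). A minimal stand-in with that behaviour, purely illustrative and not icefall's actual implementation:

class SimpleAttributeDict(dict):
    """Illustrative dict whose keys can also be accessed as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value


params = SimpleAttributeDict({"joiner_dim": 512, "model_warm_step": 3000})
assert params.model_warm_step == 3000
params.batch_idx_train = 0  # stored as an ordinary dict entry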
@@ -489,7 +485,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float = 1.0
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute transducer loss given the model and its inputs.
@@ -536,18 +532,24 @@ def compute_loss(
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (0.0 if warmup < 1.0 else
(0.1 if warmup > 1.0 and warmup < 2.0 else
1.0))
loss = (params.simple_loss_scale * simple_loss +
pruned_loss_scale * pruned_loss)
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -650,7 +652,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step)
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
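The tot_loss update just above keeps a leaky running sum: the previous total is scaled by (1 - 1/reset_interval) before the current batch's statistics are added, so old batches fade out after roughly reset_interval steps. A tiny numeric sketch with plain floats standing in for MetricsTracker (the 200 is illustrative; the real value comes from params.reset_interval):

def update_running_loss(tot_loss: float, batch_loss: float,
                        reset_interval: int = 200) -> float:
    # Old contributions decay geometrically: after reset_interval updates a
    # batch's weight has dropped to (1 - 1/reset_interval)**reset_interval,
    # roughly exp(-1).
    return tot_loss * (1 - 1 / reset_interval) + batch_loss

tot = 0.0
for batch_loss in [10.0, 9.0, 8.5, 8.0]:
    tot = update_running_loss(tot, batch_loss)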
@@ -665,8 +667,10 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if (params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0):
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
@@ -695,7 +699,9 @@ def train_one_epoch(
)
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@@ -784,18 +790,19 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Eve(
model.parameters(),
lr=params.initial_lr)
optimizer = Eve(model.parameters(), lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
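The block above follows a common resume pattern: build the optimizer (Eve) and scheduler (Eden) from scratch, then overwrite their state only if the loaded checkpoint actually carries it. A sketch of the same pattern with stock PyTorch classes standing in for Eve and Eden:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

checkpoints = {}  # e.g. the dict returned by load_checkpoint(); may be empty

if checkpoints and "optimizer" in checkpoints:
    # Restores the optimizer's moment estimates so training resumes smoothly.
    optimizer.load_state_dict(checkpoints["optimizer"])

if (
    checkpoints
    and "scheduler" in checkpoints
    and checkpoints["scheduler"] is not None
):
    # The explicit None check guards checkpoints that stored the key
    # without an actual scheduler state.
    scheduler.load_state_dict(checkpoints["scheduler"])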
@@ -805,7 +812,6 @@ def run(rank, world_size, args):
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
@@ -855,7 +861,6 @@ def run(rank, world_size, args):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = scheduler.get_last_lr()[0]
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
@@ -919,7 +924,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup = 0.0
warmup=0.0,
)
loss.backward()
optimizer.step()
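As the name and this hunk suggest, scan_pessimistic_batches_for_oom runs a forward/backward pass (with warmup=0.0, the configuration used at the very start of training) on a few worst-case batches so that an out-of-memory error surfaces before real training begins. A generic, hypothetical sketch of that idea, with placeholder model/criterion/batches:

def scan_batches_for_oom(model, criterion, batches, optimizer):
    """Run forward/backward on a few worst-case batches to surface OOM early."""
    for features, targets in batches:
        try:
            loss = criterion(model(features), targets)
            loss.backward()
            optimizer.step()
        except RuntimeError as e:
            if "out of memory" in str(e):
                raise RuntimeError(
                    "OOM on a worst-case batch; consider lowering the "
                    "batch size or maximum batch duration"
                ) from e
            raise
        finally:
            optimizer.zero_grad()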