mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
check some files
This commit is contained in:
parent
05fd40ba68
commit
187d59d59b
3
.flake8
3
.flake8
@ -7,8 +7,7 @@ per-file-ignores =
|
|||||||
egs/librispeech/ASR/*/conformer.py: E501,
|
egs/librispeech/ASR/*/conformer.py: E501,
|
||||||
egs/aishell/ASR/*/conformer.py: E501,
|
egs/aishell/ASR/*/conformer.py: E501,
|
||||||
egs/tedlium3/ASR/*/conformer.py: E501,
|
egs/tedlium3/ASR/*/conformer.py: E501,
|
||||||
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py: E501,
|
egs/librispeech/ASR/pruned_transducer_stateless2/*.py: E501,
|
||||||
egs/librispeech/ASR/pruned_transducer_stateless2/model.py: E501,
|
|
||||||
|
|
||||||
# invalid escape sequence (cause by tex formular), W605
|
# invalid escape sequence (cause by tex formular), W605
|
||||||
icefall/utils.py: E501, W605
|
icefall/utils.py: E501, W605
|
||||||
|
@ -36,16 +36,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import optim
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import optim # from .
|
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
@ -56,26 +55,23 @@ from lhotse.cut import Cut
|
|||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
from optim import Eve, Eden
|
from optim import Eden, Eve
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall import diagnostics
|
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||||
|
|
||||||
from icefall.utils import (
|
LRSchedulerType = Union[
|
||||||
AttributeDict,
|
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||||
MetricsTracker,
|
]
|
||||||
setup_logger,
|
|
||||||
str2bool,
|
|
||||||
)
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -158,7 +154,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=5000,
|
default=5000,
|
||||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||||
We suggest not to change this."""
|
We suggest not to change this.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -166,7 +162,7 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=6,
|
default=6,
|
||||||
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
help="""Number of epochs that affects how rapidly the learning rate decreases.
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -318,7 +314,7 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for joiner
|
# parameters for joiner
|
||||||
"joiner_dim": 512,
|
"joiner_dim": 512,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -489,7 +485,7 @@ def compute_loss(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
warmup: float = 1.0
|
warmup: float = 1.0,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute CTC loss given the model and its inputs.
|
||||||
@ -536,18 +532,24 @@ def compute_loss(
|
|||||||
# for the same amount of time (model_warm_step), to avoid
|
# for the same amount of time (model_warm_step), to avoid
|
||||||
# overwhelming the simple_loss and causing it to diverge,
|
# overwhelming the simple_loss and causing it to diverge,
|
||||||
# in case it had not fully learned the alignment yet.
|
# in case it had not fully learned the alignment yet.
|
||||||
pruned_loss_scale = (0.0 if warmup < 1.0 else
|
pruned_loss_scale = (
|
||||||
(0.1 if warmup > 1.0 and warmup < 2.0 else
|
0.0
|
||||||
1.0))
|
if warmup < 1.0
|
||||||
loss = (params.simple_loss_scale * simple_loss +
|
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||||
pruned_loss_scale * pruned_loss)
|
)
|
||||||
|
loss = (
|
||||||
|
params.simple_loss_scale * simple_loss
|
||||||
|
+ pruned_loss_scale * pruned_loss
|
||||||
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
info["frames"] = (
|
||||||
|
(feature_lens // params.subsampling_factor).sum().item()
|
||||||
|
)
|
||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
@ -650,7 +652,7 @@ def train_one_epoch(
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
warmup=(params.batch_idx_train / params.model_warm_step)
|
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||||
)
|
)
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
@ -665,8 +667,10 @@ def train_one_epoch(
|
|||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (params.batch_idx_train > 0
|
if (
|
||||||
and params.batch_idx_train % params.save_every_n == 0):
|
params.batch_idx_train > 0
|
||||||
|
and params.batch_idx_train % params.save_every_n == 0
|
||||||
|
):
|
||||||
params.cur_batch_idx = batch_idx
|
params.cur_batch_idx = batch_idx
|
||||||
save_checkpoint_with_global_batch_idx(
|
save_checkpoint_with_global_batch_idx(
|
||||||
out_dir=params.exp_dir,
|
out_dir=params.exp_dir,
|
||||||
@ -695,7 +699,9 @@ def train_one_epoch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train)
|
tb_writer.add_scalar(
|
||||||
|
"train/learning_rate", params.batch_idx_train
|
||||||
|
)
|
||||||
|
|
||||||
loss_info.write_summary(
|
loss_info.write_summary(
|
||||||
tb_writer, "train/current_", params.batch_idx_train
|
tb_writer, "train/current_", params.batch_idx_train
|
||||||
@ -784,18 +790,19 @@ def run(rank, world_size, args):
|
|||||||
model = DDP(model, device_ids=[rank])
|
model = DDP(model, device_ids=[rank])
|
||||||
model.device = device
|
model.device = device
|
||||||
|
|
||||||
optimizer = Eve(
|
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||||
model.parameters(),
|
|
||||||
lr=params.initial_lr)
|
|
||||||
|
|
||||||
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
|
||||||
|
|
||||||
|
|
||||||
if checkpoints and "optimizer" in checkpoints:
|
if checkpoints and "optimizer" in checkpoints:
|
||||||
logging.info("Loading optimizer state dict")
|
logging.info("Loading optimizer state dict")
|
||||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||||
|
|
||||||
if checkpoints and "scheduler" in checkpoints and checkpoints["scheduler"] is not None:
|
if (
|
||||||
|
checkpoints
|
||||||
|
and "scheduler" in checkpoints
|
||||||
|
and checkpoints["scheduler"] is not None
|
||||||
|
):
|
||||||
logging.info("Loading scheduler state dict")
|
logging.info("Loading scheduler state dict")
|
||||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||||
|
|
||||||
@ -805,7 +812,6 @@ def run(rank, world_size, args):
|
|||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
train_cuts = librispeech.train_clean_100_cuts()
|
train_cuts = librispeech.train_clean_100_cuts()
|
||||||
@ -855,7 +861,6 @@ def run(rank, world_size, args):
|
|||||||
fix_random_seed(params.seed + epoch)
|
fix_random_seed(params.seed + epoch)
|
||||||
train_dl.sampler.set_epoch(epoch)
|
train_dl.sampler.set_epoch(epoch)
|
||||||
|
|
||||||
cur_lr = scheduler.get_last_lr()[0]
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
@ -919,7 +924,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
warmup = 0.0
|
warmup=0.0,
|
||||||
)
|
)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user