Fix style issues.

Fangjun Kuang 2022-04-21 11:49:52 +08:00
parent e4d45adf5a
commit 3d0474c986
2 changed files with 32 additions and 113 deletions

File 1 of 2:

@@ -18,8 +18,8 @@
 import k2
 import torch
 import torch.nn as nn
+import torchaudio
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear

 from icefall.utils import add_sos
@@ -51,9 +51,10 @@ class Transducer(nn.Module):
             is (N, U) and its output shape is (N, U, decoder_dim).
             It should contain one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-            Its output shape is (N, T, U, vocab_size). Note that its output contains
-            unnormalized probs, i.e., not processed by log-softmax.
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output
+            contains unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
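The wrapped docstring above, together with the torchaudio import in the first hunk, describes joiner logits of shape (N, T, U, vocab_size) that are left unnormalized (no log-softmax). As a minimal sketch, assuming the loss is torchaudio's RNN-T loss and using placeholder shapes and a blank id of 0 (both assumptions, not taken from the recipe), such logits can be consumed like this:

# Hedged sketch only: random placeholder tensors; the shape conventions
# (N, T, U, vocab_size) and "unnormalized probs" come from the docstring
# above, while the blank id and concrete sizes are illustrative guesses.
import torch
import torchaudio

N, T, V = 2, 50, 500                      # batch, frames, vocab size
target_lens = torch.tensor([12, 9], dtype=torch.int32)
U = int(target_lens.max()) + 1            # +1 for the blank position

logits = torch.randn(N, T, U, V)          # unnormalized, no log-softmax applied
targets = torch.randint(1, V, (N, int(target_lens.max())), dtype=torch.int32)
logit_lens = torch.tensor([T, 43], dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lens,
    target_lengths=target_lens,
    blank=0,                              # assumed blank id
    reduction="sum",
)
print(loss)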

File 2 of 2:

@@ -21,22 +21,21 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 550
@@ -138,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="transducer_stateless3/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@ -156,7 +155,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need to be "
"changed.",
) )
parser.add_argument( parser.add_argument(
@@ -183,40 +183,6 @@ def get_parser():
         "2 means tri-gram",
     )

-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
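For context, the four options removed above controlled how the pruned recipe combined its two losses; the corresponding combination code is deleted from compute_loss further down in this diff. The sketch below merely restates that deleted logic with placeholder loss values (the scales are the removed defaults), so it is an illustration rather than code from either version of the file:

# Restatement of the deleted combination logic, with made-up loss values.
import torch

simple_loss = torch.tensor(120.0)   # placeholder
pruned_loss = torch.tensor(80.0)    # placeholder
simple_loss_scale = 0.5             # default of the removed --simple-loss-scale
warmup = 1.5                        # was batch_idx_train / model_warm_step

# pruned_loss is ignored before warmup reaches 1.0, down-weighted until 2.0,
# and fully counted afterwards (as in the removed code in compute_loss).
pruned_loss_scale = 0.0 if warmup < 1.0 else (0.1 if warmup < 2.0 else 1.0)

loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
print(loss)  # tensor(68.)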
@@ -255,13 +221,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--use-fp16",
-        type=str2bool,
-        default=False,
-        help="Whether to use half precision training.",
-    )
-
     return parser
@@ -318,7 +277,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 3000,  # For the 100h subset, use 1600
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -506,7 +465,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -523,8 +481,6 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device
     feature = batch["inputs"]
@@ -540,27 +496,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            prune_range=params.prune_range,
-            am_scale=params.am_scale,
-            lm_scale=params.lm_scale,
-            warmup=warmup,
-        )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
-        )
+        )

     assert loss.requires_grad == is_training
@@ -574,8 +513,6 @@ def compute_loss(
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
-    info["simple_loss"] = simple_loss.detach().cpu().item()
-    info["pruned_loss"] = pruned_loss.detach().cpu().item()

     return loss, info
@@ -622,7 +559,6 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -646,8 +582,6 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
-      scaler:
-        The scaler used for mix precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -670,25 +604,22 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=True,
+        )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.

-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
         optimizer.zero_grad()
+        loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
+        optimizer.step()

         if params.print_diagnostics and batch_idx == 5:
             return
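After this hunk, the optimization step no longer goes through GradScaler: gradients are cleared, the loss is backpropagated, the per-batch LR schedule advances, and the optimizer steps. A self-contained sketch of that order follows; the model, data, and the DummyScheduler stand-in (icefall's scheduler exposes step_batch, but everything else here is invented for illustration) are not from the recipe:

# Minimal stand-alone sketch of the non-AMP update order used above.
import torch
import torch.nn as nn


class DummyScheduler:
    """Stand-in for a scheduler exposing step_batch(); it does nothing here."""

    def step_batch(self, batch_idx_train: int) -> None:
        pass


model = nn.Linear(10, 5)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
scheduler = DummyScheduler()

batch_idx_train = 0
for _ in range(3):
    batch_idx_train += 1
    x = torch.randn(8, 10)
    loss = model(x).pow(2).sum()           # toy loss in place of compute_loss()

    optimizer.zero_grad()                  # clear gradients from the previous step
    loss.backward()                        # backprop
    scheduler.step_batch(batch_idx_train)  # advance the per-batch LR schedule
    optimizer.step()                       # update the parameters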
@@ -706,7 +637,6 @@ def train_one_epoch(
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
-                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -883,11 +813,6 @@ def run(rank, world_size, args):
             params=params,
         )

-    scaler = GradScaler(enabled=params.use_fp16)
-    if checkpoints and "grad_scaler" in checkpoints:
-        logging.info("Loading grad scaler state dict")
-        scaler.load_state_dict(checkpoints["grad_scaler"])
-
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
@@ -906,7 +831,6 @@ def run(rank, world_size, args):
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
-            scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -922,7 +846,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
-            scaler=scaler,
             rank=rank,
         )
@@ -949,21 +872,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
-                    warmup=0.0,
-                )
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+            )
+            optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            optimizer.zero_grad()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
                 logging.error(