Fix style issues.

Fangjun Kuang 2022-04-21 11:49:52 +08:00
parent e4d45adf5a
commit 3d0474c986
2 changed files with 32 additions and 113 deletions

View File

@@ -18,8 +18,8 @@
import k2
import torch
import torch.nn as nn
import torchaudio
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
@@ -51,9 +51,10 @@ class Transducer(nn.Module):
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
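The joiner itself is not shown in this diff; the following is a rough, self-contained sketch of how tensors with the shapes described in the docstring above combine. The ToyJoiner class and its projection layers are hypothetical; only the input/output shapes and the "unnormalized probs" note come from the docstring.

```python
import torch
import torch.nn as nn


class ToyJoiner(nn.Module):
    """Illustrative joiner (not the icefall implementation).

    Inputs : (N, T, encoder_dim) and (N, U, decoder_dim)
    Output : (N, T, U, vocab_size), unnormalized (no log-softmax applied).
    """

    def __init__(self, encoder_dim: int, decoder_dim: int, vocab_size: int):
        super().__init__()
        self.enc_proj = nn.Linear(encoder_dim, vocab_size)
        self.dec_proj = nn.Linear(decoder_dim, vocab_size)

    def forward(self, enc_out: torch.Tensor, dec_out: torch.Tensor) -> torch.Tensor:
        # (N, T, 1, V) + (N, 1, U, V) -> (N, T, U, V) via broadcasting
        logits = self.enc_proj(enc_out).unsqueeze(2) + self.dec_proj(dec_out).unsqueeze(1)
        return logits  # caller applies log_softmax(-1) only if it needs log-probs


joiner = ToyJoiner(encoder_dim=512, decoder_dim=512, vocab_size=500)
out = joiner(torch.randn(2, 100, 512), torch.randn(2, 10, 512))
assert out.shape == (2, 100, 10, 500)
```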

View File

@@ -21,22 +21,21 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless2/train.py \
./transducer_stateless3/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir transducer_stateless3/exp \
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless2/train.py \
./transducer_stateless3/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir transducer_stateless3/exp \
--full-libri 1 \
--max-duration 550
@@ -138,7 +137,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="transducer_stateless3/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -156,7 +155,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need to be "
"changed.",
)
parser.add_argument(
@@ -183,40 +183,6 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--seed",
type=int,
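The pruned-RNN-T options removed in the hunk above (--prune-range, --lm-scale, --am-scale, --simple-loss-scale) feed the loss computation that is also dropped from compute_loss() later in this diff. Below is a self-contained illustration of how --simple-loss-scale and the warmup-dependent pruned-loss scale combine the two losses; the numeric tensors are dummies, and the combination rule is copied from the removed compute_loss() code shown further down.

```python
import torch

simple_loss_scale = 0.5            # default of the removed --simple-loss-scale option
simple_loss = torch.tensor(120.0)  # dummy stand-in for the smoothed "simple" loss
pruned_loss = torch.tensor(80.0)   # dummy stand-in for the pruned loss

for warmup in (0.5, 1.5, 2.5):
    # Schedule taken from the removed compute_loss() code below:
    # 0.0 before warmup, 0.1 while warming up, 1.0 afterwards.
    pruned_loss_scale = (
        0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    )
    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    print(warmup, loss.item())  # -> 60.0, 68.0, 140.0
```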
@@ -255,13 +221,6 @@ def get_parser():
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
@@ -318,7 +277,7 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"valid_interval": 3000, # For the 100h subset, use 1600
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
@@ -506,7 +465,6 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute transducer loss given the model and its inputs.
@@ -523,8 +481,6 @@ def compute_loss(
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device
feature = batch["inputs"]
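The is_training flag described above toggles autograd through torch.set_grad_enabled, as the retained code just below shows. A tiny standalone illustration of that standard PyTorch pattern and of the requires_grad check that compute_loss() asserts afterwards:

```python
import torch

x = torch.randn(3, requires_grad=True)

for is_training in (True, False):
    with torch.set_grad_enabled(is_training):
        loss = (x * 2.0).sum()
    # Gradients are tracked only in training mode, matching the
    # `assert loss.requires_grad == is_training` in compute_loss().
    assert loss.requires_grad == is_training
```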
@@ -540,27 +496,10 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
assert loss.requires_grad == is_training
@@ -574,8 +513,6 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
return loss, info
@@ -622,7 +559,6 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@@ -646,8 +582,6 @@
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
@@ -670,25 +604,22 @@
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
)
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss.backward()
scheduler.step_batch(params.batch_idx_train)
optimizer.step()
if params.print_diagnostics and batch_idx == 5:
return
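The replacement above swaps the torch.cuda.amp machinery (autocast plus GradScaler) for a plain full-precision step. For reference, here is a minimal standalone sketch of both variants with a dummy model; it uses only the standard PyTorch API and assumes a CUDA device, as the training script itself does.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 10, device="cuda"), torch.randn(4, 2, device="cuda")

use_fp16 = True  # stands in for the removed --use-fp16 flag
scaler = GradScaler(enabled=use_fp16)

# Mixed precision step (the pattern removed by this commit):
with autocast(enabled=use_fp16):
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

# Plain full-precision step (the replacement):
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```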
@@ -706,7 +637,6 @@
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
@@ -883,11 +813,6 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch)
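The block removed above restored the GradScaler state when resuming from a checkpoint; without it, a resumed mixed-precision run restarts the dynamic loss-scale schedule from scratch. A minimal sketch of saving and restoring that state is shown below; it uses only the standard GradScaler API, and the "checkpoint.pt" path and dict layout are illustrative, modeled on the "grad_scaler" key seen above.

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=torch.cuda.is_available())

# Saving: store the scaler state alongside the model/optimizer states.
checkpoint = {"grad_scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pt")

# Resuming: restore it so the dynamic loss scale continues where it left off.
checkpoints = torch.load("checkpoint.pt")
if "grad_scaler" in checkpoints:
    scaler.load_state_dict(checkpoints["grad_scaler"])
```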
@@ -906,7 +831,6 @@ def run(rank, world_size, args):
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@@ -922,7 +846,6 @@ def run(rank, world_size, args):
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
@@ -949,21 +872,16 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=0.0,
)
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(