Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 16:14:17 +00:00)

Commit 3d0474c986 ("Fix style issues.")
Parent: e4d45adf5a
@@ -18,8 +18,8 @@
 import k2
 import torch
 import torch.nn as nn
+import torchaudio
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear

 from icefall.utils import add_sos

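Note on the hunk above: the new code imports torchaudio and drops ScaledLinear, and the rest of this diff removes all pruned-rnnt options, which suggests the model's forward now returns a single unpruned RNN-T loss. A minimal, self-contained sketch of the torchaudio loss call such a model would presumably rely on (dimensions and variable names are illustrative, not taken from the diff):

import torch
import torchaudio

# Toy dimensions: batch, frames, label length, vocab size, blank id (all illustrative).
N, T, U, V, blank_id = 2, 50, 10, 500, 0

logits = torch.randn(N, T, U + 1, V, requires_grad=True)  # joiner output, unnormalized
targets = torch.randint(1, V, (N, U), dtype=torch.int32)  # label sequences (no blanks)
logit_lengths = torch.full((N,), T, dtype=torch.int32)    # valid frames per utterance
target_lengths = torch.full((N,), U, dtype=torch.int32)   # valid labels per utterance

# torchaudio.functional.rnnt_loss expects unnormalized logits of shape
# (N, T, U + 1, V); reduction="sum" matches the "reduction=sum" note in compute_loss.
loss = torchaudio.functional.rnnt_loss(
    logits,
    targets,
    logit_lengths,
    target_lengths,
    blank=blank_id,
    reduction="sum",
)
loss.backward()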
@@ -51,9 +51,10 @@ class Transducer(nn.Module):
             is (N, U) and its output shape is (N, U, decoder_dim).
             It should contain one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-            Its output shape is (N, T, U, vocab_size). Note that its output contains
-            unnormalized probs, i.e., not processed by log-softmax.
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output
+            contains unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -21,22 +21,21 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 550

@@ -138,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="transducer_stateless3/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -156,7 +155,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to be "
+        "changed.",
     )

     parser.add_argument(
@@ -183,40 +183,6 @@ def get_parser():
         "2 means tri-gram",
     )

-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
@@ -255,13 +221,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--use-fp16",
-        type=str2bool,
-        default=False,
-        help="Whether to use half precision training.",
-    )
-
     return parser


@@ -318,7 +277,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 3000,  # For the 100h subset, use 1600
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -506,7 +465,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -523,8 +481,6 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device
     feature = batch["inputs"]
@@ -540,27 +496,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            prune_range=params.prune_range,
-            am_scale=params.am_scale,
-            lm_scale=params.lm_scale,
-            warmup=warmup,
-        )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
         )

     assert loss.requires_grad == is_training
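For reference, the piecewise warm-up factor deleted in this hunk can be read as the following small helper (an equivalent restatement for clarity, not code from the repository); the deleted code multiplied it into the pruned-loss term before adding it to the scaled simple loss:

def pruned_loss_scale(warmup: float) -> float:
    # Equivalent to the deleted expression:
    #   0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
    if warmup < 1.0:
        return 0.0  # before warm-up: the pruned loss is ignored
    if 1.0 < warmup < 2.0:
        return 0.1  # shortly after warm-up: keep its contribution small
    return 1.0      # fully warmed up (also returned at exactly warmup == 1.0)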
@@ -574,8 +513,6 @@ def compute_loss(

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
-    info["simple_loss"] = simple_loss.detach().cpu().item()
-    info["pruned_loss"] = pruned_loss.detach().cpu().item()

     return loss, info

@@ -622,7 +559,6 @@ def train_one_epoch(
    sp: spm.SentencePieceProcessor,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
@@ -646,8 +582,6 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
-      scaler:
-        The scaler used for mix precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -670,25 +604,22 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=True,
+        )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
         optimizer.zero_grad()
+        loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
+        optimizer.step()

         if params.print_diagnostics and batch_idx == 5:
             return
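The four scaler calls removed above follow the standard PyTorch AMP recipe and only make sense as a group: the loss is scaled before backward(), step() unscales the gradients and skips the update if they overflowed, and update() adjusts the scale factor. A generic, self-contained sketch of that pattern (plain PyTorch, not icefall code; assumes a CUDA device is available):

import torch

device = torch.device("cuda")  # autocast/GradScaler as used here target CUDA
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=True)

x = torch.randn(4, 10, device=device)
y = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads; skips the step if they contain inf/nan
scaler.update()                # adjusts the loss scale for the next iteration

The replacement code keeps a conventional full-precision step instead: zero_grad(), backward(), scheduler.step_batch(), optimizer.step().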
@@ -706,7 +637,6 @@ def train_one_epoch(
                optimizer=optimizer,
                scheduler=scheduler,
                sampler=train_dl.sampler,
-                scaler=scaler,
                rank=rank,
            )
            del params.cur_batch_idx
@@ -883,11 +813,6 @@ def run(rank, world_size, args):
            params=params,
        )

-    scaler = GradScaler(enabled=params.use_fp16)
-    if checkpoints and "grad_scaler" in checkpoints:
-        logging.info("Loading grad scaler state dict")
-        scaler.load_state_dict(checkpoints["grad_scaler"])
-
    for epoch in range(params.start_epoch, params.num_epochs):
        scheduler.step_epoch(epoch)
        fix_random_seed(params.seed + epoch)
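The removed lines restored the scaler's dynamic loss scale when resuming from a checkpoint; the round-trip is just state_dict()/load_state_dict(). A generic sketch of that pattern (the file name and dict layout are illustrative; only the "grad_scaler" key matches the deleted code):

import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Saving: store the scaler state alongside the rest of the checkpoint.
torch.save({"grad_scaler": scaler.state_dict()}, "checkpoint-demo.pt")

# Resuming: restore it so the loss scale continues where it left off.
checkpoints = torch.load("checkpoint-demo.pt")
if checkpoints and "grad_scaler" in checkpoints:
    scaler.load_state_dict(checkpoints["grad_scaler"])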
@@ -906,7 +831,6 @@ def run(rank, world_size, args):
            sp=sp,
            train_dl=train_dl,
            valid_dl=valid_dl,
-            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
@@ -922,7 +846,6 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            scheduler=scheduler,
            sampler=train_dl.sampler,
-            scaler=scaler,
            rank=rank,
        )

@@ -949,21 +872,16 @@ def scan_pessimistic_batches_for_oom(
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
-                    warmup=0.0,
-                )
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+            )
+            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
-            optimizer.zero_grad()
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.error(
|
Loading…
x
Reference in New Issue
Block a user