black format

wangtiance 2023-10-26 18:34:08 +08:00
parent 7c7c9da05a
commit 9ec52fdb77
7 changed files with 131 additions and 125 deletions

View File

@@ -264,8 +264,8 @@ class LibriSpeechAsrDataModule:
                         num_feature_masks=2,
                         features_mask_size=5,
                         num_frame_masks=10,
                         frames_mask_size=5,
                         p=0.5,
                     )
                 )
             else:
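
For context, the hunk above only re-flows the SpecAugment configuration. A minimal sketch of building the same transform with lhotse directly, assuming lhotse's current import path and using the parameter values from the hunk (the input tensor is illustrative):

import torch
from lhotse.dataset import SpecAugment

spec_aug = SpecAugment(
    num_feature_masks=2,   # frequency masks per utterance
    features_mask_size=5,  # width of each frequency mask, in bins
    num_frame_masks=10,    # time masks per utterance
    frames_mask_size=5,    # width of each time mask, in frames
    p=0.5,                 # probability of augmenting an utterance
)
feats = torch.randn(4, 500, 80)  # (batch, time, features)
augmented = spec_aug(feats)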

View File

@@ -373,7 +373,6 @@ def decode_one_batch(
         # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
         lm_scale_list = [0.6, 0.7, 0.8, 0.9]

     if params.decoding_method == "nbest-rescoring":
         best_path_dict = rescore_with_n_best_list(
             lattice=lattice,
@@ -507,9 +506,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -576,7 +573,9 @@ def main():
     params.blank_id = 0

     if params.decoding_method == "ctc-decoding":
-        assert "lang_bpe" in str(params.lang_dir), "ctc-decoding only supports BPE lexicons."
+        assert "lang_bpe" in str(
+            params.lang_dir
+        ), "ctc-decoding only supports BPE lexicons."
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
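
For readers following the ctc-decoding branch: the call the hunk truncates builds the CTC topology FSA H with k2. A minimal, hedged sketch of that call in isolation (the max_token_id value here is illustrative; in the script it comes from the BPE lexicon):

import k2
import torch

max_token_id = 500  # illustrative vocabulary upper bound
H = k2.ctc_topo(
    max_token=max_token_id,
    modified=False,  # standard CTC topology
    device=torch.device("cpu"),
)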

View File

@@ -39,7 +39,6 @@ from icefall.utils import (
 LOG_EPS = math.log(1e-10)


 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -493,7 +492,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
         params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)

View File

@@ -26,7 +26,7 @@ from torch import Tensor, nn


 class Conv1dNet(EncoderInterface):
     """
     1D Convolution network with causal squeeze and excitation
     module and optional skip connections.
     Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.
@@ -34,11 +34,11 @@ class Conv1dNet(EncoderInterface):
     Args:
         output_dim (int): Number of output channels of the last layer.
         input_dim (int): Number of input features
         conv_layers (int): Number of convolution layers,
             excluding the subsampling layers.
         channels (int): Number of output channels for each layer,
             except the last layer.
         subsampling_factor (int): The subsampling factor for the model.
         skip_add (bool): Whether to use skip connection for each convolution layer.
         dscnn (bool): Whether to use depthwise-separated convolution.
         activation (str): Activation function type.
@@ -53,7 +53,7 @@ class Conv1dNet(EncoderInterface):
         subsampling_factor: int = 4,
         skip_add: bool = False,
         dscnn: bool = True,
-        activation: str = 'relu',
+        activation: str = "relu",
     ) -> None:
         super().__init__()
         assert subsampling_factor == 4, "Only support subsampling = 4"
@@ -62,10 +62,12 @@ class Conv1dNet(EncoderInterface):
         self.skip_add = skip_add
         # 80ms latency for subsample_layer
         self.subsample_layer = nn.Sequential(
-            conv1d_bn_block(input_dim, channels, 9,
-                            stride=2, activation=activation, dscnn=dscnn),
-            conv1d_bn_block(channels, channels, 5,
-                            stride=2, activation=activation, dscnn=dscnn),
+            conv1d_bn_block(
+                input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
+            ),
+            conv1d_bn_block(
+                channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
+            ),
         )

         self.conv_blocks = nn.ModuleList()
@@ -82,13 +84,15 @@ class Conv1dNet(EncoderInterface):
                     3,
                     activation=activation,
                     dscnn=dscnn,
-                    causal=ly % 2),
-                CausalSqueezeExcite1d(cout[ly], 16, 30)
+                    causal=ly % 2,
+                ),
+                CausalSqueezeExcite1d(cout[ly], 16, 30),
             )
         )

     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         x_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -104,24 +108,25 @@ class Conv1dNet(EncoderInterface):
           - lengths, a tensor of shape (batch_size,) containing the number
             of frames in `embeddings` before padding.
         """
         x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
         x = self.subsample_layer(x)
         for idx, layer in enumerate(self.conv_blocks):
-            if self.skip_add and 0 < idx < self.conv_layers-1:
+            if self.skip_add and 0 < idx < self.conv_layers - 1:
                 x = layer(x) + x
             else:
                 x = layer(x)
         x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         lengths = x_lens >> 2
         return x, lengths


-def get_activation(name: str,
-                   channels: int,
-                   channel_dim: int = -1,
-                   min_val: int = 0,
-                   max_val: int = 1,
-                   ) -> nn.Module:
+def get_activation(
+    name: str,
+    channels: int,
+    channel_dim: int = -1,
+    min_val: int = 0,
+    max_val: int = 1,
+) -> nn.Module:
     """
     Get activation function from name in string.
@@ -140,44 +145,42 @@ def get_activation(name: str,
     """
     act_layer = nn.Identity()
     name = name.lower()
-    if name == 'prelu':
+    if name == "prelu":
         act_layer = nn.PReLU(channels)
-    elif name == 'relu':
+    elif name == "relu":
         act_layer = nn.ReLU()
-    elif name == 'relu6':
+    elif name == "relu6":
         act_layer = nn.ReLU6()
-    elif name == 'hardtanh':
+    elif name == "hardtanh":
         act_layer = nn.Hardtanh(min_val, max_val)
-    elif name in ['swish', 'silu']:
+    elif name in ["swish", "silu"]:
         act_layer = nn.SiLU()
-    elif name == 'elu':
+    elif name == "elu":
         act_layer = nn.ELU()
-    elif name == 'doubleswish':
+    elif name == "doubleswish":
         act_layer = nn.Sequential(
-            ActivationBalancer(
-                num_channels=channels,
-                channel_dim=channel_dim),
+            ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
             DoubleSwish(),
         )
-    elif name == '':
+    elif name == "":
         act_layer = nn.Identity()
     else:
-        raise Exception(f'Unknown activation function: {name}')
+        raise Exception(f"Unknown activation function: {name}")
     return act_layer


 class CausalSqueezeExcite1d(nn.Module):
     """
     Causal squeeze and excitation module with input and output shape
     (batch, channels, time). The global average pooling in the original
     SE module is replaced by a causal filter, so
     the layer does not introduce any algorithmic latency.

     Args:
         channels (int): Number of channels
         reduction (int): channel reduction rate
         context_window (int): Context window size for the moving average operation.
             For EMA, the smoothing factor is 1 / context_window.
     """
@@ -205,8 +208,8 @@ class CausalSqueezeExcite1d(nn.Module):
         self.ema_matrix_size = 0

     def _precompute_ema_matrix(self, N: int, device: torch.device):
         a = 1.0 / self.context_window  # smoothing factor
-        w = [[(1-a)**k * a for k in range(n, n-N, -1)] for n in range(N)]
+        w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
         w = torch.tensor(w).to(device).tril()
         w[:, 0] *= self.context_window
         self.ema_matrix = w.T
@@ -218,8 +221,8 @@ class CausalSqueezeExcite1d(nn.Module):
             y[t] = (1-a) * y[t-1] + a * x[t]
         where a = 1 / self.context_window is the smoothing factor.

         For training, the iterative version is too slow. A better way is
         to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
         The weight matrix can be precomputed if the smoothing factor is fixed.
         """
         if self.training:
@@ -234,29 +237,29 @@ class CausalSqueezeExcite1d(nn.Module):
             y = torch.empty_like(x)
             y[:, :, 0] = x[:, :, 0]
             for t in range(1, y.shape[-1]):
-                y[:, :, t] = (1-a) * y[:, :, t-1] + a * x[:, :, t]
+                y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
         return y

     def moving_avg(self, x: Tensor) -> Tensor:
         """
         Simple moving average with context_window as window size.
         """
         y = torch.empty_like(x)
         k = min(x.shape[2], self.context_window)
-        w = [[1/n] * n + [0] * (k-n-1) for n in range(1, k)]
+        w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
         w = torch.tensor(w, device=x.device)
-        y[:, :, :k-1] = torch.matmul(x[:, :, :k-1], w.T)
-        y[:, :, k-1:] = F.avg_pool1d(x, k, 1)
+        y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
+        y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
         return y

     def forward(self, x: Tensor) -> Tensor:
         assert len(x.shape) == 3, "Input is not a 3D tensor!"
         y = self.exponential_moving_avg(x)
         y = y.permute(0, 2, 1)  # make channel last for squeeze op
         y = self.act1(self.linear1(y))
         y = self.act2(self.linear2(y))
         y = y.permute(0, 2, 1)  # back to original shape
         y = x * y
         return y
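
A note on the `_precompute_ema_matrix` trick reformatted above: row n of the lower-triangular weight matrix holds (1 - a) ** (n - j) * a for columns j <= n, and rescaling column 0 by context_window (that is, by 1/a) makes y[0] = x[0], so a single matmul reproduces the whole recursion y[t] = (1 - a) * y[t - 1] + a * x[t]. A self-contained numerical check of this equivalence (all names are local to the sketch):

import torch

N, a = 6, 1.0 / 4  # sequence length; smoothing factor for context_window = 4
w = torch.tensor(
    [[(1 - a) ** (n - j) * a for j in range(N)] for n in range(N)]
).tril()
w[:, 0] /= a  # so that y[0] = x[0] exactly

x = torch.randn(N)
y_matmul = w @ x

# reference: the iterative EMA
y_loop = torch.empty(N)
y_loop[0] = x[0]
for t in range(1, N):
    y_loop[t] = (1 - a) * y_loop[t - 1] + a * x[t]

assert torch.allclose(y_matmul, y_loop)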
@@ -267,12 +270,12 @@ def conv1d_bn_block(
     kernel_size: int = 3,
     stride: int = 1,
     dilation: int = 1,
-    activation: str = 'relu',
+    activation: str = "relu",
     dscnn: bool = False,
     causal: bool = False,
 ) -> nn.Sequential:
     """
     Conv1d - batchnorm - activation block.
     If kernel size is even, output length = input length + 1.
     Otherwise, output and input lengths are equal.
@@ -296,16 +299,19 @@ def conv1d_bn_block(
                 stride=stride,
                 dilation=dilation,
                 groups=in_channels,
-                bias=False) if causal else
-            nn.Conv1d(
+                bias=False,
+            )
+            if causal
+            else nn.Conv1d(
                 in_channels,
                 in_channels,
                 kernel_size,
                 stride=stride,
-                padding=(kernel_size//2) * dilation,
+                padding=(kernel_size // 2) * dilation,
                 dilation=dilation,
                 groups=in_channels,
-                bias=False),
+                bias=False,
+            ),
             nn.BatchNorm1d(in_channels),
             get_activation(activation, in_channels),
             nn.Conv1d(in_channels, out_channels, 1, bias=False),
@@ -320,15 +326,18 @@ def conv1d_bn_block(
             kernel_size,
             stride=stride,
             dilation=dilation,
-            bias=False) if causal else
-        nn.Conv1d(
+            bias=False,
+        )
+        if causal
+        else nn.Conv1d(
             in_channels,
             out_channels,
             kernel_size,
             stride=stride,
-            padding=(kernel_size//2) * dilation,
+            padding=(kernel_size // 2) * dilation,
             dilation=dilation,
-            bias=False),
+            bias=False,
+        ),
         nn.BatchNorm1d(out_channels),
         get_activation(activation, out_channels),
     )
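
The docstring claim in conv1d_bn_block ("if kernel size is even, output length = input length + 1") follows from the padding formula (kernel_size // 2) * dilation being applied on both sides. A quick check with a bare nn.Conv1d, using toy sizes assumed only for illustration:

import torch
from torch import nn

T = 20  # illustrative input length
for k in (3, 4):  # odd vs. even kernel size, dilation = 1
    conv = nn.Conv1d(1, 1, k, stride=1, padding=k // 2)
    out = conv(torch.randn(1, 1, T))
    print(k, out.shape[-1])  # k=3 -> 20 (equal), k=4 -> 21 (input + 1)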
@@ -336,7 +345,7 @@ def conv1d_bn_block(


 class CausalConv1d(nn.Module):
     """
     Causal convolution with padding automatically chosen to match input/output length.
     """

     def __init__(
@@ -352,11 +361,19 @@ class CausalConv1d(nn.Module):
         super(CausalConv1d, self).__init__()
         assert kernel_size > 2
-        self.padding = dilation*(kernel_size-1)
+        self.padding = dilation * (kernel_size - 1)
         self.stride = stride
-        self.conv = nn.Conv1d(in_channels, out_channels,
-                              kernel_size, stride, self.padding, dilation, groups, bias=bias)
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            self.padding,
+            dilation,
+            groups,
+            bias=bias,
+        )

     def forward(self, x: Tensor) -> Tensor:
-        return self.conv(x)[:, :, :-self.padding // self.stride]
+        return self.conv(x)[:, :, : -self.padding // self.stride]
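
The causal padding scheme above pads both sides by dilation * (kernel_size - 1) (nn.Conv1d pads symmetrically) and then trims the right edge, so for stride 1 the output length matches the input and no future frames leak in. A standalone check of the length arithmetic for the stride=1 case (sizes are illustrative):

import torch
from torch import nn

k, d, T = 5, 2, 30  # kernel size, dilation, input length
padding = d * (k - 1)
conv = nn.Conv1d(1, 1, k, stride=1, padding=padding, dilation=d)
y = conv(torch.randn(1, 1, T))[:, :, :-padding]  # trim the acausal tail
assert y.shape[-1] == T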

View File

@@ -78,9 +78,12 @@ from pathlib import Path

 import sentencepiece as spm
 import torch
-from icefall.checkpoint import (average_checkpoints,
-                                average_checkpoints_with_averaged_model,
-                                find_checkpoints, load_checkpoint)
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import UniqLexicon
 from icefall.utils import str2bool
 from train import add_model_arguments, get_params, get_transducer_model
@@ -188,7 +191,7 @@ def main():
     if "lang_bpe" in str(params.lang_dir):
         sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + '/bpe.model')
+        sp.load(params.lang_dir + "/bpe.model")
         # <blk> is defined in local/train_bpe_model.py
         params.blank_id = sp.piece_to_id("<blk>")
         params.vocab_size = sp.get_piece_size()

View File

@@ -228,7 +228,7 @@ def main():
     params.update(vars(args))

     sp = spm.SentencePieceProcessor()
-    sp.load(params.lang_dir + '/bpe.model')
+    sp.load(params.lang_dir + "/bpe.model")

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")

View File

@@ -79,15 +79,12 @@ from icefall.utils import (
 LRSchedulerType = torch.optim.lr_scheduler._LRScheduler


-def set_batch_count(
-    model: Union[nn.Module, DDP],
-    batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
     for module in model.modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count
@@ -140,7 +137,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="""Use skip connection in the encoder.
         """,
     )


 def get_parser():
     parser = argparse.ArgumentParser(
@@ -226,8 +223,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )

     parser.add_argument(
@@ -250,8 +246,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )

     parser.add_argument(
@@ -450,7 +445,7 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    if is_module_available('thop'):
+    if is_module_available("thop"):
         if torch.cuda.is_available():
             device = torch.device("cuda:0")
         else:
@@ -459,13 +454,14 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
         x = torch.zeros((1, 1000, params.feature_dim)).to(device)
         x_lens = torch.Tensor([1000]).int().to(device)
         from thop import clever_format, profile

         m = copy.deepcopy(encoder)
         m = m.to(device)
         ops, _ = clever_format(profile(m, (x, x_lens), verbose=False))
-        logging.info(f'Encoder MAC ops for 10 seconds of audio is {ops}')
+        logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}")
     else:
-        logging.info('You can install thop to calculate the number of ops.')
-        logging.info('Command: pip install thop')
+        logging.info("You can install thop to calculate the number of ops.")
+        logging.info("Command: pip install thop")

     model = Transducer(
         encoder=encoder,
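
For readers unfamiliar with the profiling block above: thop estimates multiply-accumulate counts by tracing a forward pass. A minimal standalone sketch of the same pattern (the toy model and shapes are illustrative stand-ins, not the icefall encoder):

import torch
from thop import clever_format, profile

model = torch.nn.Conv1d(80, 256, 5)  # toy stand-in for the encoder
x = torch.randn(1, 80, 1000)  # roughly 10 s of 10 ms frames
macs, params = profile(model, inputs=(x,), verbose=False)
macs, params = clever_format([macs, params], "%.2f")
print(f"MACs: {macs}, params: {params}")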
@@ -624,11 +620,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -662,18 +654,17 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss = (
-            simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
-        )
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
# Compute ctc loss # Compute ctc loss
@@ -704,16 +695,14 @@ def compute_loss(
         )
         assert ctc_loss.requires_grad == is_training
         assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1"
-        loss = params.ctc_loss_scale * ctc_loss + (1-params.ctc_loss_scale) * loss
+        loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss

     assert loss.requires_grad == is_training

     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -843,8 +832,9 @@ def train_one_epoch(
                 scaler.update()
                 optimizer.zero_grad()
             except:  # noqa
-                display_and_save_batch(batch, params=params,
-                                       sp=sp, phone_lexicon=phone_lexicon)
+                display_and_save_batch(
+                    batch, params=params, sp=sp, phone_lexicon=phone_lexicon
+                )
                 raise

         if params.print_diagnostics and batch_idx == 5:
@@ -896,7 +886,8 @@ def train_one_epoch(
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
                 raise RuntimeError(
-                    f"grad_scale is too small, exiting: {cur_grad_scale}")
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )

         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -918,9 +909,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -938,11 +927,10 @@ def train_one_epoch(
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
             logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+                valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)

     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -987,7 +975,7 @@ def run(rank, world_size, args):

     if "lang_bpe" in str(params.lang_dir):
         sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + '/bpe.model')
+        sp.load(params.lang_dir + "/bpe.model")
         # <blk> is defined in local/train_bpe_model.py
         params.blank_id = sp.piece_to_id("<blk>")
         params.vocab_size = sp.get_piece_size()
@@ -1031,8 +1019,7 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = AdamW(
         model.parameters(),
@@ -1056,7 +1043,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1108,8 +1095,7 @@ def run(rank, world_size, args):
     #     params=params,
     # )

-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1235,11 +1221,13 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
-                display_and_save_batch(batch, params=params,
-                                       sp=sp, phone_lexicon=phone_lexicon)
+                display_and_save_batch(
+                    batch, params=params, sp=sp, phone_lexicon=phone_lexicon
+                )
                 raise
         logging.info(
-            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        )


 def main():