mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 23:54:17 +00:00
black format
This commit is contained in:
parent
7c7c9da05a
commit
9ec52fdb77
@ -264,8 +264,8 @@ class LibriSpeechAsrDataModule:
|
||||
num_feature_masks=2,
|
||||
features_mask_size=5,
|
||||
num_frame_masks=10,
|
||||
frames_mask_size=5,
|
||||
p=0.5,
|
||||
frames_mask_size=5,
|
||||
p=0.5,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -373,7 +373,6 @@ def decode_one_batch(
|
||||
# lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
lm_scale_list = [0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
|
||||
if params.decoding_method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
@ -507,9 +506,7 @@ def save_results(
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
@ -576,7 +573,9 @@ def main():
|
||||
params.blank_id = 0
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
assert "lang_bpe" in str(params.lang_dir), "ctc-decoding only supports BPE lexicons."
|
||||
assert "lang_bpe" in str(
|
||||
params.lang_dir
|
||||
), "ctc-decoding only supports BPE lexicons."
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
|
@ -39,7 +39,6 @@ from icefall.utils import (
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -493,7 +492,7 @@ def save_results(
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
/ f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
|
@ -26,7 +26,7 @@ from torch import Tensor, nn
|
||||
|
||||
class Conv1dNet(EncoderInterface):
|
||||
"""
|
||||
1D Convolution network with causal squeeze and excitation
|
||||
1D Convolution network with causal squeeze and excitation
|
||||
module and optional skip connections.
|
||||
|
||||
Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.
|
||||
@ -34,11 +34,11 @@ class Conv1dNet(EncoderInterface):
|
||||
Args:
|
||||
output_dim (int): Number of output channels of the last layer.
|
||||
input_dim (int): Number of input features
|
||||
conv_layers (int): Number of convolution layers,
|
||||
conv_layers (int): Number of convolution layers,
|
||||
excluding the subsampling layers.
|
||||
channels (int): Number of output channels for each layer,
|
||||
channels (int): Number of output channels for each layer,
|
||||
except the last layer.
|
||||
subsampling_factor (int): The subsampling factor for the model.
|
||||
subsampling_factor (int): The subsampling factor for the model.
|
||||
skip_add (bool): Whether to use skip connection for each convolution layer.
|
||||
dscnn (bool): Whether to use depthwise-separated convolution.
|
||||
activation (str): Activation function type.
|
||||
@ -53,7 +53,7 @@ class Conv1dNet(EncoderInterface):
|
||||
subsampling_factor: int = 4,
|
||||
skip_add: bool = False,
|
||||
dscnn: bool = True,
|
||||
activation: str = 'relu',
|
||||
activation: str = "relu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert subsampling_factor == 4, "Only support subsampling = 4"
|
||||
@ -62,10 +62,12 @@ class Conv1dNet(EncoderInterface):
|
||||
self.skip_add = skip_add
|
||||
# 80ms latency for subsample_layer
|
||||
self.subsample_layer = nn.Sequential(
|
||||
conv1d_bn_block(input_dim, channels, 9,
|
||||
stride=2, activation=activation, dscnn=dscnn),
|
||||
conv1d_bn_block(channels, channels, 5,
|
||||
stride=2, activation=activation, dscnn=dscnn),
|
||||
conv1d_bn_block(
|
||||
input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
|
||||
),
|
||||
conv1d_bn_block(
|
||||
channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
|
||||
),
|
||||
)
|
||||
|
||||
self.conv_blocks = nn.ModuleList()
|
||||
@ -82,13 +84,15 @@ class Conv1dNet(EncoderInterface):
|
||||
3,
|
||||
activation=activation,
|
||||
dscnn=dscnn,
|
||||
causal=ly % 2),
|
||||
CausalSqueezeExcite1d(cout[ly], 16, 30)
|
||||
causal=ly % 2,
|
||||
),
|
||||
CausalSqueezeExcite1d(cout[ly], 16, 30),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -104,24 +108,25 @@ class Conv1dNet(EncoderInterface):
|
||||
- lengths, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `embeddings` before padding.
|
||||
"""
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.subsample_layer(x)
|
||||
for idx, layer in enumerate(self.conv_blocks):
|
||||
if self.skip_add and 0 < idx < self.conv_layers-1:
|
||||
if self.skip_add and 0 < idx < self.conv_layers - 1:
|
||||
x = layer(x) + x
|
||||
else:
|
||||
x = layer(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
lengths = x_lens >> 2
|
||||
return x, lengths
|
||||
|
||||
|
||||
def get_activation(name: str,
|
||||
channels: int,
|
||||
channel_dim: int = -1,
|
||||
min_val: int = 0,
|
||||
max_val: int = 1,
|
||||
) -> nn.Module:
|
||||
def get_activation(
|
||||
name: str,
|
||||
channels: int,
|
||||
channel_dim: int = -1,
|
||||
min_val: int = 0,
|
||||
max_val: int = 1,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Get activation function from name in string.
|
||||
|
||||
@ -140,44 +145,42 @@ def get_activation(name: str,
|
||||
"""
|
||||
act_layer = nn.Identity()
|
||||
name = name.lower()
|
||||
if name == 'prelu':
|
||||
if name == "prelu":
|
||||
act_layer = nn.PReLU(channels)
|
||||
elif name == 'relu':
|
||||
elif name == "relu":
|
||||
act_layer = nn.ReLU()
|
||||
elif name == 'relu6':
|
||||
elif name == "relu6":
|
||||
act_layer = nn.ReLU6()
|
||||
elif name == 'hardtanh':
|
||||
elif name == "hardtanh":
|
||||
act_layer = nn.Hardtanh(min_val, max_val)
|
||||
elif name in ['swish', 'silu']:
|
||||
elif name in ["swish", "silu"]:
|
||||
act_layer = nn.SiLU()
|
||||
elif name == 'elu':
|
||||
elif name == "elu":
|
||||
act_layer = nn.ELU()
|
||||
elif name == 'doubleswish':
|
||||
elif name == "doubleswish":
|
||||
act_layer = nn.Sequential(
|
||||
ActivationBalancer(
|
||||
num_channels=channels,
|
||||
channel_dim=channel_dim),
|
||||
ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
|
||||
DoubleSwish(),
|
||||
)
|
||||
elif name == '':
|
||||
elif name == "":
|
||||
act_layer = nn.Identity()
|
||||
else:
|
||||
raise Exception(f'Unknown activation function: {name}')
|
||||
raise Exception(f"Unknown activation function: {name}")
|
||||
|
||||
return act_layer
|
||||
|
||||
|
||||
class CausalSqueezeExcite1d(nn.Module):
|
||||
"""
|
||||
Causal squeeze and excitation module with input and output shape
|
||||
(batch, channels, time). The global average pooling in the original
|
||||
Causal squeeze and excitation module with input and output shape
|
||||
(batch, channels, time). The global average pooling in the original
|
||||
SE module is replaced by a causal filter, so
|
||||
the layer does not introduce any algorithmic latency.
|
||||
the layer does not introduce any algorithmic latency.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels
|
||||
reduction (int): channel reduction rate
|
||||
context_window (int): Context window size for the moving average operation.
|
||||
context_window (int): Context window size for the moving average operation.
|
||||
For EMA, the smoothing factor is 1 / context_window.
|
||||
"""
|
||||
|
||||
@ -205,8 +208,8 @@ class CausalSqueezeExcite1d(nn.Module):
|
||||
self.ema_matrix_size = 0
|
||||
|
||||
def _precompute_ema_matrix(self, N: int, device: torch.device):
|
||||
a = 1.0 / self.context_window # smoothing factor
|
||||
w = [[(1-a)**k * a for k in range(n, n-N, -1)] for n in range(N)]
|
||||
a = 1.0 / self.context_window # smoothing factor
|
||||
w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
|
||||
w = torch.tensor(w).to(device).tril()
|
||||
w[:, 0] *= self.context_window
|
||||
self.ema_matrix = w.T
|
||||
@ -218,8 +221,8 @@ class CausalSqueezeExcite1d(nn.Module):
|
||||
y[t] = (1-a) * y[t-1] + a * x[t]
|
||||
where a = 1 / self.context_window is the smoothing factor.
|
||||
|
||||
For training, the iterative version is too slow. A better way is
|
||||
to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
|
||||
For training, the iterative version is too slow. A better way is
|
||||
to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
|
||||
The weight matrix can be precomputed if the smoothing factor is fixed.
|
||||
"""
|
||||
if self.training:
|
||||
@ -234,29 +237,29 @@ class CausalSqueezeExcite1d(nn.Module):
|
||||
y = torch.empty_like(x)
|
||||
y[:, :, 0] = x[:, :, 0]
|
||||
for t in range(1, y.shape[-1]):
|
||||
y[:, :, t] = (1-a) * y[:, :, t-1] + a * x[:, :, t]
|
||||
y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
|
||||
return y
|
||||
|
||||
def moving_avg(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Simple moving average with context_window as window size.
|
||||
Simple moving average with context_window as window size.
|
||||
"""
|
||||
y = torch.empty_like(x)
|
||||
k = min(x.shape[2], self.context_window)
|
||||
w = [[1/n] * n + [0] * (k-n-1) for n in range(1, k)]
|
||||
w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
|
||||
w = torch.tensor(w, device=x.device)
|
||||
y[:, :, :k-1] = torch.matmul(x[:, :, :k-1], w.T)
|
||||
y[:, :, k-1:] = F.avg_pool1d(x, k, 1)
|
||||
y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
|
||||
y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
|
||||
return y
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
assert len(x.shape) == 3, "Input is not a 3D tensor!"
|
||||
y = self.exponential_moving_avg(x)
|
||||
y = y.permute(0, 2, 1) # make channel last for squeeze op
|
||||
y = y.permute(0, 2, 1) # make channel last for squeeze op
|
||||
y = self.act1(self.linear1(y))
|
||||
y = self.act2(self.linear2(y))
|
||||
y = y.permute(0, 2, 1) # back to original shape
|
||||
y = y.permute(0, 2, 1) # back to original shape
|
||||
y = x * y
|
||||
return y
|
||||
|
||||
@ -267,12 +270,12 @@ def conv1d_bn_block(
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
activation: str = 'relu',
|
||||
activation: str = "relu",
|
||||
dscnn: bool = False,
|
||||
causal: bool = False,
|
||||
) -> nn.Sequential:
|
||||
"""
|
||||
Conv1d - batchnorm - activation block.
|
||||
Conv1d - batchnorm - activation block.
|
||||
If kernel size is even, output length = input length + 1.
|
||||
Otherwise, output and input lengths are equal.
|
||||
|
||||
@ -296,16 +299,19 @@ def conv1d_bn_block(
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False) if causal else
|
||||
nn.Conv1d(
|
||||
bias=False,
|
||||
)
|
||||
if causal
|
||||
else nn.Conv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size//2) * dilation,
|
||||
padding=(kernel_size // 2) * dilation,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False),
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm1d(in_channels),
|
||||
get_activation(activation, in_channels),
|
||||
nn.Conv1d(in_channels, out_channels, 1, bias=False),
|
||||
@ -320,15 +326,18 @@ def conv1d_bn_block(
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
bias=False) if causal else
|
||||
nn.Conv1d(
|
||||
bias=False,
|
||||
)
|
||||
if causal
|
||||
else nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size//2) * dilation,
|
||||
padding=(kernel_size // 2) * dilation,
|
||||
dilation=dilation,
|
||||
bias=False),
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
get_activation(activation, out_channels),
|
||||
)
|
||||
@ -336,7 +345,7 @@ def conv1d_bn_block(
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
"""
|
||||
Causal convolution with padding automatically chosen to match input/output length.
|
||||
Causal convolution with padding automatically chosen to match input/output length.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -352,11 +361,19 @@ class CausalConv1d(nn.Module):
|
||||
super(CausalConv1d, self).__init__()
|
||||
assert kernel_size > 2
|
||||
|
||||
self.padding = dilation*(kernel_size-1)
|
||||
self.padding = dilation * (kernel_size - 1)
|
||||
self.stride = stride
|
||||
|
||||
self.conv = nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size, stride, self.padding, dilation, groups, bias=bias)
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
self.padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.conv(x)[:, :, :-self.padding // self.stride]
|
||||
return self.conv(x)[:, :, : -self.padding // self.stride]
|
||||
|
@ -78,9 +78,12 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from icefall.checkpoint import (average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints, load_checkpoint)
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import str2bool
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
@ -188,7 +191,7 @@ def main():
|
||||
|
||||
if "lang_bpe" in str(params.lang_dir):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.lang_dir + '/bpe.model')
|
||||
sp.load(params.lang_dir + "/bpe.model")
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
@ -228,7 +228,7 @@ def main():
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.lang_dir + '/bpe.model')
|
||||
sp.load(params.lang_dir + "/bpe.model")
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
|
@ -79,15 +79,12 @@ from icefall.utils import (
|
||||
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
||||
|
||||
|
||||
def set_batch_count(
|
||||
model: Union[nn.Module, DDP],
|
||||
batch_count: float
|
||||
) -> None:
|
||||
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||
if isinstance(model, DDP):
|
||||
# get underlying nn.Module
|
||||
model = model.module
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'batch_count'):
|
||||
if hasattr(module, "batch_count"):
|
||||
module.batch_count = batch_count
|
||||
|
||||
|
||||
@ -140,7 +137,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="""Use skip connection in the encoder.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -226,8 +223,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -250,8 +246,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -450,7 +445,7 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
if is_module_available('thop'):
|
||||
if is_module_available("thop"):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
@ -459,13 +454,14 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
|
||||
x = torch.zeros((1, 1000, params.feature_dim)).to(device)
|
||||
x_lens = torch.Tensor([1000]).int().to(device)
|
||||
from thop import clever_format, profile
|
||||
|
||||
m = copy.deepcopy(encoder)
|
||||
m = m.to(device)
|
||||
ops, _ = clever_format(profile(m, (x, x_lens), verbose=False))
|
||||
logging.info(f'Encoder MAC ops for 10 seconds of audio is {ops}')
|
||||
logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}")
|
||||
else:
|
||||
logging.info('You can install thop to calculate the number of ops.')
|
||||
logging.info('Command: pip install thop')
|
||||
logging.info("You can install thop to calculate the number of ops.")
|
||||
logging.info("Command: pip install thop")
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
@ -624,11 +620,7 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -662,18 +654,17 @@ def compute_loss(
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s if batch_idx_train >= warm_step
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0 if batch_idx_train >= warm_step
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
|
||||
loss = (
|
||||
simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
)
|
||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
# Compute ctc loss
|
||||
|
||||
@ -704,16 +695,14 @@ def compute_loss(
|
||||
)
|
||||
assert ctc_loss.requires_grad == is_training
|
||||
assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1"
|
||||
loss = params.ctc_loss_scale * ctc_loss + (1-params.ctc_loss_scale) * loss
|
||||
loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -843,8 +832,9 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except: # noqa
|
||||
display_and_save_batch(batch, params=params,
|
||||
sp=sp, phone_lexicon=phone_lexicon)
|
||||
display_and_save_batch(
|
||||
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
|
||||
)
|
||||
raise
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -896,7 +886,8 @@ def train_one_epoch(
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
cur_lr = scheduler.get_last_lr()[0]
|
||||
@ -918,9 +909,7 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
@ -938,11 +927,10 @@ def train_one_epoch(
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
params.train_loss = loss_value
|
||||
@ -987,7 +975,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if "lang_bpe" in str(params.lang_dir):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.lang_dir + '/bpe.model')
|
||||
sp.load(params.lang_dir + "/bpe.model")
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
@ -1031,8 +1019,7 @@ def run(rank, world_size, args):
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank],
|
||||
find_unused_parameters=True)
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = AdamW(
|
||||
model.parameters(),
|
||||
@ -1056,7 +1043,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2 ** 22
|
||||
2**22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
@ -1108,8 +1095,7 @@ def run(rank, world_size, args):
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16,
|
||||
init_scale=1.0)
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
@ -1235,11 +1221,13 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(batch, params=params,
|
||||
sp=sp, phone_lexicon=phone_lexicon)
|
||||
display_and_save_batch(
|
||||
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
|
||||
)
|
||||
raise
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
Loading…
x
Reference in New Issue
Block a user