black format

This commit is contained in:
wangtiance 2023-10-26 18:34:08 +08:00
parent 7c7c9da05a
commit 9ec52fdb77
7 changed files with 131 additions and 125 deletions

View File

@@ -264,8 +264,8 @@ class LibriSpeechAsrDataModule:
num_feature_masks=2,
features_mask_size=5,
num_frame_masks=10,
frames_mask_size=5,
p=0.5,
)
)
else:

View File

@@ -373,7 +373,6 @@ def decode_one_batch(
# lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
lm_scale_list = [0.6, 0.7, 0.8, 0.9]
if params.decoding_method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
@@ -507,9 +506,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@@ -576,7 +573,9 @@ def main():
params.blank_id = 0
if params.decoding_method == "ctc-decoding":
assert "lang_bpe" in str(params.lang_dir), "ctc-decoding only supports BPE lexicons."
assert "lang_bpe" in str(
params.lang_dir
), "ctc-decoding only supports BPE lexicons."
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,

View File

@@ -39,7 +39,6 @@ from icefall.utils import (
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -493,7 +492,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

View File

@@ -26,7 +26,7 @@ from torch import Tensor, nn
class Conv1dNet(EncoderInterface):
"""
1D Convolution network with causal squeeze and excitation
module and optional skip connections.
Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.
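For example, conv_layers = 9 gives 80ms + (9 + 1) // 2 * 40ms = 280ms.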
@@ -34,11 +34,11 @@ class Conv1dNet(EncoderInterface):
Args:
output_dim (int): Number of output channels of the last layer.
input_dim (int): Number of input features
conv_layers (int): Number of convolution layers,
excluding the subsampling layers.
channels (int): Number of output channels for each layer,
except the last layer.
subsampling_factor (int): The subsampling factor for the model.
skip_add (bool): Whether to use skip connection for each convolution layer.
dscnn (bool): Whether to use depthwise-separated convolution.
activation (str): Activation function type.
@@ -53,7 +53,7 @@ class Conv1dNet(EncoderInterface):
subsampling_factor: int = 4,
skip_add: bool = False,
dscnn: bool = True,
activation: str = 'relu',
activation: str = "relu",
) -> None:
super().__init__()
assert subsampling_factor == 4, "Only support subsampling = 4"
@@ -62,10 +62,12 @@ class Conv1dNet(EncoderInterface):
self.skip_add = skip_add
# 80ms latency for subsample_layer
self.subsample_layer = nn.Sequential(
conv1d_bn_block(input_dim, channels, 9,
stride=2, activation=activation, dscnn=dscnn),
conv1d_bn_block(channels, channels, 5,
stride=2, activation=activation, dscnn=dscnn),
conv1d_bn_block(
input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
),
conv1d_bn_block(
channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
),
)
self.conv_blocks = nn.ModuleList()
@@ -82,13 +84,15 @@ class Conv1dNet(EncoderInterface):
3,
activation=activation,
dscnn=dscnn,
causal=ly % 2),
CausalSqueezeExcite1d(cout[ly], 16, 30)
causal=ly % 2,
),
CausalSqueezeExcite1d(cout[ly], 16, 30),
)
)
def forward(
self, x: torch.Tensor,
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -104,24 +108,25 @@ class Conv1dNet(EncoderInterface):
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.subsample_layer(x)
for idx, layer in enumerate(self.conv_blocks):
if self.skip_add and 0 < idx < self.conv_layers-1:
if self.skip_add and 0 < idx < self.conv_layers - 1:
x = layer(x) + x
else:
x = layer(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
lengths = x_lens >> 2
return x, lengths
def get_activation(name: str,
channels: int,
channel_dim: int = -1,
min_val: int = 0,
max_val: int = 1,
) -> nn.Module:
def get_activation(
name: str,
channels: int,
channel_dim: int = -1,
min_val: int = 0,
max_val: int = 1,
) -> nn.Module:
"""
Get an activation function by name.
@@ -140,44 +145,42 @@ def get_activation(name: str,
"""
act_layer = nn.Identity()
name = name.lower()
if name == 'prelu':
if name == "prelu":
act_layer = nn.PReLU(channels)
elif name == 'relu':
elif name == "relu":
act_layer = nn.ReLU()
elif name == 'relu6':
elif name == "relu6":
act_layer = nn.ReLU6()
elif name == 'hardtanh':
elif name == "hardtanh":
act_layer = nn.Hardtanh(min_val, max_val)
elif name in ['swish', 'silu']:
elif name in ["swish", "silu"]:
act_layer = nn.SiLU()
elif name == 'elu':
elif name == "elu":
act_layer = nn.ELU()
elif name == 'doubleswish':
elif name == "doubleswish":
act_layer = nn.Sequential(
ActivationBalancer(
num_channels=channels,
channel_dim=channel_dim),
ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
DoubleSwish(),
)
elif name == '':
elif name == "":
act_layer = nn.Identity()
else:
raise Exception(f'Unknown activation function: {name}')
raise Exception(f"Unknown activation function: {name}")
return act_layer
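A minimal usage sketch of get_activation as defined above (the tensor shape is illustrative; channels only matters for parametric activations such as prelu and doubleswish):

    act = get_activation("prelu", channels=64)
    y = act(torch.randn(2, 64, 50))  # per-channel PReLU over dim 1, shape unchanged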
class CausalSqueezeExcite1d(nn.Module):
"""
Causal squeeze and excitation module with input and output shape
(batch, channels, time). The global average pooling in the original
SE module is replaced by a causal filter, so
the layer does not introduce any algorithmic latency.
Args:
channels (int): Number of channels
reduction (int): channel reduction rate
context_window (int): Context window size for the moving average operation.
For EMA, the smoothing factor is 1 / context_window.
"""
@@ -205,8 +208,8 @@ class CausalSqueezeExcite1d(nn.Module):
self.ema_matrix_size = 0
def _precompute_ema_matrix(self, N: int, device: torch.device):
a = 1.0 / self.context_window # smoothing factor
w = [[(1-a)**k * a for k in range(n, n-N, -1)] for n in range(N)]
w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
w = torch.tensor(w).to(device).tril()
w[:, 0] *= self.context_window
self.ema_matrix = w.T
@@ -218,8 +221,8 @@ class CausalSqueezeExcite1d(nn.Module):
y[t] = (1-a) * y[t-1] + a * x[t]
where a = 1 / self.context_window is the smoothing factor.
For training, the iterative version is too slow. A better way is
to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
The weight matrix can be precomputed if the smoothing factor is fixed.
"""
if self.training:
@@ -234,29 +237,29 @@ class CausalSqueezeExcite1d(nn.Module):
y = torch.empty_like(x)
y[:, :, 0] = x[:, :, 0]
for t in range(1, y.shape[-1]):
y[:, :, t] = (1-a) * y[:, :, t-1] + a * x[:, :, t]
y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
return y
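To see why the precomputed matrix reproduces the recurrence, unroll it: y[t] = (1 - a)^t * x[0] + a * sum_{j=1..t} (1 - a)^(t - j) * x[j], which is exactly row t of the lower-triangular matrix built in _precompute_ema_matrix (the first column is rescaled so that y[0] = x[0]). A small self-contained sketch, with illustrative sizes, checking that the two paths agree:

    import torch

    N, a = 6, 1.0 / 4  # time steps; smoothing factor for context_window = 4
    x = torch.randn(1, 2, N)  # (batch, channels, time)

    # iterative recurrence
    y_iter = torch.empty_like(x)
    y_iter[..., 0] = x[..., 0]
    for t in range(1, N):
        y_iter[..., t] = (1 - a) * y_iter[..., t - 1] + a * x[..., t]

    # matrix form, mirroring _precompute_ema_matrix
    w = torch.tensor(
        [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
    ).tril()
    w[:, 0] /= a  # restore y[0] = x[0]
    assert torch.allclose(torch.matmul(x, w.T), y_iter, atol=1e-5)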
def moving_avg(self, x: Tensor) -> Tensor:
"""
Simple moving average with context_window as window size.
"""
y = torch.empty_like(x)
k = min(x.shape[2], self.context_window)
w = [[1/n] * n + [0] * (k-n-1) for n in range(1, k)]
w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
w = torch.tensor(w, device=x.device)
y[:, :, :k-1] = torch.matmul(x[:, :, :k-1], w.T)
y[:, :, k-1:] = F.avg_pool1d(x, k, 1)
y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
return y
def forward(self, x: Tensor) -> Tensor:
assert len(x.shape) == 3, "Input is not a 3D tensor!"
y = self.exponential_moving_avg(x)
y = y.permute(0, 2, 1) # make channel last for squeeze op
y = self.act1(self.linear1(y))
y = self.act2(self.linear2(y))
y = y.permute(0, 2, 1) # back to original shape
y = x * y
return y
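A shape-level sketch of the module (the constructor argument order is assumed from the Args list above and the CausalSqueezeExcite1d(cout[ly], 16, 30) call in Conv1dNet):

    se = CausalSqueezeExcite1d(64, 16, 30)  # channels, reduction, context_window
    x = torch.randn(4, 64, 50)  # (batch, channels, time)
    assert se(x).shape == x.shape  # the gate multiplies x elementwise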
@@ -267,12 +270,12 @@ def conv1d_bn_block(
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
activation: str = 'relu',
activation: str = "relu",
dscnn: bool = False,
causal: bool = False,
) -> nn.Sequential:
"""
Conv1d - batchnorm - activation block.
If kernel size is even, output length = input length + 1.
Otherwise, output and input lengths are equal.
@@ -296,16 +299,19 @@ def conv1d_bn_block(
stride=stride,
dilation=dilation,
groups=in_channels,
bias=False) if causal else
nn.Conv1d(
bias=False,
)
if causal
else nn.Conv1d(
in_channels,
in_channels,
kernel_size,
stride=stride,
padding=(kernel_size//2) * dilation,
padding=(kernel_size // 2) * dilation,
dilation=dilation,
groups=in_channels,
bias=False),
bias=False,
),
nn.BatchNorm1d(in_channels),
get_activation(activation, in_channels),
nn.Conv1d(in_channels, out_channels, 1, bias=False),
@@ -320,15 +326,18 @@ def conv1d_bn_block(
kernel_size,
stride=stride,
dilation=dilation,
bias=False) if causal else
nn.Conv1d(
bias=False,
)
if causal
else nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=(kernel_size//2) * dilation,
padding=(kernel_size // 2) * dilation,
dilation=dilation,
bias=False),
bias=False,
),
nn.BatchNorm1d(out_channels),
get_activation(activation, out_channels),
)
@@ -336,7 +345,7 @@ def conv1d_bn_block(
class CausalConv1d(nn.Module):
"""
Causal convolution with padding automatically chosen to match input/output length.
"""
def __init__(
@@ -352,11 +361,19 @@ class CausalConv1d(nn.Module):
super(CausalConv1d, self).__init__()
assert kernel_size > 2
self.padding = dilation*(kernel_size-1)
self.padding = dilation * (kernel_size - 1)
self.stride = stride
self.conv = nn.Conv1d(in_channels, out_channels,
kernel_size, stride, self.padding, dilation, groups, bias=bias)
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
self.padding,
dilation,
groups,
bias=bias,
)
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)[:, :, :-self.padding // self.stride]
return self.conv(x)[:, :, : -self.padding // self.stride]
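A quick check of the length-matching claim in the class docstring (the constructor is assumed to mirror nn.Conv1d's argument order and defaults, as the body above suggests):

    conv = CausalConv1d(8, 8, kernel_size=5)  # padding = 4, trimmed off the tail
    x = torch.randn(2, 8, 100)  # (batch, channels, time)
    assert conv(x).shape == x.shape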

View File

@@ -78,9 +78,12 @@ from pathlib import Path
import sentencepiece as spm
import torch
from icefall.checkpoint import (average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints, load_checkpoint)
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import str2bool
from train import add_model_arguments, get_params, get_transducer_model
@@ -188,7 +191,7 @@ def main():
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@@ -228,7 +228,7 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")

View File

@@ -79,15 +79,12 @@ from icefall.utils import (
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def set_batch_count(
model: Union[nn.Module, DDP],
batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
@@ -140,7 +137,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="""Use skip connection in the encoder.
""",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -226,8 +223,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@@ -250,8 +246,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@@ -450,7 +445,7 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
if is_module_available('thop'):
if is_module_available("thop"):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
@@ -459,13 +454,14 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
x = torch.zeros((1, 1000, params.feature_dim)).to(device)
x_lens = torch.Tensor([1000]).int().to(device)
from thop import clever_format, profile
m = copy.deepcopy(encoder)
m = m.to(device)
ops, _ = clever_format(profile(m, (x, x_lens), verbose=False))
logging.info(f'Encoder MAC ops for 10 seconds of audio is {ops}')
logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}")
else:
logging.info('You can install thop to calculate the number of ops.')
logging.info('Command: pip install thop')
logging.info("You can install thop to calculate the number of ops.")
logging.info("Command: pip install thop")
model = Transducer(
encoder=encoder,
@@ -624,11 +620,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@@ -662,18 +654,17 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
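A numeric sketch of this ramp, assuming illustrative values s = 0.5 (params.simple_loss_scale) and warm_step = 2000:

    batch_idx_train, warm_step, s = 500, 2000, 0.5
    simple = s if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step)
    # simple = 0.875, pruned = 0.325: the simple loss dominates early, the pruned loss ramps in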
# Compute ctc loss
@@ -704,16 +695,14 @@ def compute_loss(
)
assert ctc_loss.requires_grad == is_training
assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1"
loss = params.ctc_loss_scale * ctc_loss + (1-params.ctc_loss_scale) * loss
loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@@ -843,8 +832,9 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params,
sp=sp, phone_lexicon=phone_lexicon)
display_and_save_batch(
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
)
raise
if params.print_diagnostics and batch_idx == 5:
@@ -896,7 +886,8 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}")
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -918,9 +909,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -938,11 +927,10 @@ def train_one_epoch(
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@@ -987,7 +975,7 @@ def run(rank, world_size, args):
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
@@ -1031,8 +1019,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = AdamW(
model.parameters(),
@@ -1056,7 +1043,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1108,8 +1095,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1235,11 +1221,13 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params,
sp=sp, phone_lexicon=phone_lexicon)
display_and_save_batch(
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():