black format

wangtiance 2023-10-26 18:34:08 +08:00
parent 7c7c9da05a
commit 9ec52fdb77
7 changed files with 131 additions and 125 deletions
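
The hunks below are the mechanical output of running the black formatter over these recipe files, presumably with its defaults (88-column lines, double quotes, trailing-comma explosion of argument lists); the commit title suggests no behaviour changes. As a minimal sketch of what black does, its public Python API can be invoked directly (assumes pip install black; format_str and Mode are black's real API, the snippet itself is only illustrative):

import black

src = "x = {'a': 1,'b': 2}\n"  # hand-formatted input
# Mode() holds the defaults: 88-column lines, double quotes preferred.
print(black.format_str(src, mode=black.Mode()), end="")
# prints: x = {"a": 1, "b": 2}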

View File

@ -373,7 +373,6 @@ def decode_one_batch(
# lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
lm_scale_list = [0.6, 0.7, 0.8, 0.9]
if params.decoding_method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
@ -507,9 +506,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
errs_info = params.res_dir / f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
@ -576,7 +573,9 @@ def main():
params.blank_id = 0
if params.decoding_method == "ctc-decoding":
assert "lang_bpe" in str(params.lang_dir), "ctc-decoding only supports BPE lexicons."
assert "lang_bpe" in str(
params.lang_dir
), "ctc-decoding only supports BPE lexicons."
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,

View File

@ -39,7 +39,6 @@ from icefall.utils import (
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -493,7 +492,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
/ f"{wer}-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

View File

@ -53,7 +53,7 @@ class Conv1dNet(EncoderInterface):
subsampling_factor: int = 4,
skip_add: bool = False,
dscnn: bool = True,
activation: str = 'relu',
activation: str = "relu",
) -> None:
super().__init__()
assert subsampling_factor == 4, "Only support subsampling = 4"
@ -62,10 +62,12 @@ class Conv1dNet(EncoderInterface):
self.skip_add = skip_add
# 80ms latency for subsample_layer
self.subsample_layer = nn.Sequential(
conv1d_bn_block(input_dim, channels, 9,
stride=2, activation=activation, dscnn=dscnn),
conv1d_bn_block(channels, channels, 5,
stride=2, activation=activation, dscnn=dscnn),
conv1d_bn_block(
input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
),
conv1d_bn_block(
channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
),
)
self.conv_blocks = nn.ModuleList()
@ -82,13 +84,15 @@ class Conv1dNet(EncoderInterface):
3,
activation=activation,
dscnn=dscnn,
causal=ly % 2),
CausalSqueezeExcite1d(cout[ly], 16, 30)
causal=ly % 2,
),
CausalSqueezeExcite1d(cout[ly], 16, 30),
)
)
def forward(
self, x: torch.Tensor,
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@ -116,7 +120,8 @@ class Conv1dNet(EncoderInterface):
return x, lengths
def get_activation(name: str,
def get_activation(
name: str,
channels: int,
channel_dim: int = -1,
min_val: int = 0,
@ -140,29 +145,27 @@ def get_activation(name: str,
"""
act_layer = nn.Identity()
name = name.lower()
if name == 'prelu':
if name == "prelu":
act_layer = nn.PReLU(channels)
elif name == 'relu':
elif name == "relu":
act_layer = nn.ReLU()
elif name == 'relu6':
elif name == "relu6":
act_layer = nn.ReLU6()
elif name == 'hardtanh':
elif name == "hardtanh":
act_layer = nn.Hardtanh(min_val, max_val)
elif name in ['swish', 'silu']:
elif name in ["swish", "silu"]:
act_layer = nn.SiLU()
elif name == 'elu':
elif name == "elu":
act_layer = nn.ELU()
elif name == 'doubleswish':
elif name == "doubleswish":
act_layer = nn.Sequential(
ActivationBalancer(
num_channels=channels,
channel_dim=channel_dim),
ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
DoubleSwish(),
)
elif name == '':
elif name == "":
act_layer = nn.Identity()
else:
raise Exception(f'Unknown activation function: {name}')
raise Exception(f"Unknown activation function: {name}")
return act_layer
@ -267,7 +270,7 @@ def conv1d_bn_block(
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
activation: str = 'relu',
activation: str = "relu",
dscnn: bool = False,
causal: bool = False,
) -> nn.Sequential:
@ -296,8 +299,10 @@ def conv1d_bn_block(
stride=stride,
dilation=dilation,
groups=in_channels,
bias=False) if causal else
nn.Conv1d(
bias=False,
)
if causal
else nn.Conv1d(
in_channels,
in_channels,
kernel_size,
@ -305,7 +310,8 @@ def conv1d_bn_block(
padding=(kernel_size // 2) * dilation,
dilation=dilation,
groups=in_channels,
bias=False),
bias=False,
),
nn.BatchNorm1d(in_channels),
get_activation(activation, in_channels),
nn.Conv1d(in_channels, out_channels, 1, bias=False),
@ -320,15 +326,18 @@ def conv1d_bn_block(
kernel_size,
stride=stride,
dilation=dilation,
bias=False) if causal else
nn.Conv1d(
bias=False,
)
if causal
else nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=(kernel_size // 2) * dilation,
dilation=dilation,
bias=False),
bias=False,
),
nn.BatchNorm1d(out_channels),
get_activation(activation, out_channels),
)
@ -355,8 +364,16 @@ class CausalConv1d(nn.Module):
self.padding = dilation * (kernel_size - 1)
self.stride = stride
self.conv = nn.Conv1d(in_channels, out_channels,
kernel_size, stride, self.padding, dilation, groups, bias=bias)
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
self.padding,
dilation,
groups,
bias=bias,
)
def forward(self, x: Tensor) -> Tensor:
return self.conv(x)[:, :, : -self.padding // self.stride]
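
An aside on the CausalConv1d hunk above: nn.Conv1d's padding argument pads both ends, so the layer pads by dilation * (kernel_size - 1) and forward() slices off the trailing frames, leaving each output frame dependent only on current and past inputs. A minimal sketch of that trick (hypothetical channel and length values, stride fixed at 1; not code from this commit):

import torch
import torch.nn as nn

k, d = 3, 1
pad = d * (k - 1)  # same padding rule as CausalConv1d above
conv = nn.Conv1d(4, 4, k, stride=1, padding=pad, dilation=d)
x = torch.randn(1, 4, 10)
y = conv(x)[:, :, :-pad]  # the right-trim from forward(), with stride == 1
print(y.shape)  # torch.Size([1, 4, 10]); length preserved, no look-ahead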

View File

@ -78,9 +78,12 @@ from pathlib import Path
import sentencepiece as spm
import torch
from icefall.checkpoint import (average_checkpoints,
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints, load_checkpoint)
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import str2bool
from train import add_model_arguments, get_params, get_transducer_model
@ -188,7 +191,7 @@ def main():
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@ -228,7 +228,7 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")

View File

@ -79,15 +79,12 @@ from icefall.utils import (
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def set_batch_count(
model: Union[nn.Module, DDP],
batch_count: float
) -> None:
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
if hasattr(module, 'batch_count'):
if hasattr(module, "batch_count"):
module.batch_count = batch_count
@ -226,8 +223,7 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -250,8 +246,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -450,7 +445,7 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
if is_module_available('thop'):
if is_module_available("thop"):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
@ -459,13 +454,14 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
x = torch.zeros((1, 1000, params.feature_dim)).to(device)
x_lens = torch.Tensor([1000]).int().to(device)
from thop import clever_format, profile
m = copy.deepcopy(encoder)
m = m.to(device)
ops, _ = clever_format(profile(m, (x, x_lens), verbose=False))
logging.info(f'Encoder MAC ops for 10 seconds of audio is {ops}')
logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}")
else:
logging.info('You can install thop to calculate the number of ops.')
logging.info('Command: pip install thop')
logging.info("You can install thop to calculate the number of ops.")
logging.info("Command: pip install thop")
model = Transducer(
encoder=encoder,
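
For readers new to the thop idiom being reflowed above: profile() returns raw MAC and parameter counts, and clever_format() renders them as human-readable strings. A self-contained sketch (assumes pip install thop; the toy module and shapes are invented, not this recipe's encoder):

import torch
from thop import clever_format, profile

m = torch.nn.Conv1d(80, 256, 9)  # stand-in for an encoder
x = torch.randn(1, 80, 1000)  # (batch, feature_dim, frames)
macs, params = profile(m, inputs=(x,), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, params: {params}")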
@ -624,11 +620,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -662,18 +654,17 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
# Compute ctc loss
@ -711,9 +702,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -843,8 +832,9 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
except: # noqa
display_and_save_batch(batch, params=params,
sp=sp, phone_lexicon=phone_lexicon)
display_and_save_batch(
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
)
raise
if params.print_diagnostics and batch_idx == 5:
@ -896,7 +886,8 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}")
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@ -918,9 +909,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@ -938,11 +927,10 @@ def train_one_epoch(
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
@ -987,7 +975,7 @@ def run(rank, world_size, args):
if "lang_bpe" in str(params.lang_dir):
sp = spm.SentencePieceProcessor()
sp.load(params.lang_dir + '/bpe.model')
sp.load(params.lang_dir + "/bpe.model")
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
@ -1031,8 +1019,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank],
find_unused_parameters=True)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = AdamW(
model.parameters(),
@ -1108,8 +1095,7 @@ def run(rank, world_size, args):
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
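
The GradScaler lines above follow the standard torch.cuda.amp pattern: scale the loss before backward, then step and update through the scaler. A generic sketch of that loop (requires a CUDA device; the model and optimizer here are placeholders, not this recipe's):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters())
scaler = GradScaler(enabled=True, init_scale=1.0)

x = torch.randn(4, 10, device="cuda")
with autocast(enabled=True):
    loss = model(x).square().mean()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)  # unscales grads, skips the step on inf/nan
scaler.update()
optimizer.zero_grad()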
@ -1235,11 +1221,13 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params,
sp=sp, phone_lexicon=phone_lexicon)
display_and_save_batch(
batch, params=params, sp=sp, phone_lexicon=phone_lexicon
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def main():