Minor fixes.

Fangjun Kuang 2021-09-24 20:34:34 +08:00
parent 6f5d63492a
commit 9e6bd0f07c
5 changed files with 119 additions and 149 deletions

View File

@@ -99,7 +99,7 @@ def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
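A minimal sketch of the merge described above, assuming the recipe combines parsed options with the defaults via a plain dict update; the AttributeDict below is a simplified stand-in for icefall.utils.AttributeDict:

import argparse

class AttributeDict(dict):
    # Simplified stand-in for icefall.utils.AttributeDict:
    # dict entries become attributes.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

parser = argparse.ArgumentParser()
parser.add_argument("--num-epochs", type=int, default=50)
args = parser.parse_args([])

params = AttributeDict({"lr": 1e-3})
params.update(vars(args))  # commandline options merged into `params`
assert params.num_epochs == 50 and params.lr == 1e-3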

View File

@@ -1,4 +1,21 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -11,21 +28,20 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from tdnn_lstm_ctc.model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
encode_supervisions,
@@ -61,28 +77,22 @@ def get_parser():
)
parser.add_argument(
"--use-ali-model",
type=str2bool,
default=True,
help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py "
"and you have some checkpoints inside the directory "
"tdnn_lstm_ctc/exp_bpe_500 ."
"It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
"as the pre-trained alignment model",
)
parser.add_argument(
"--ali-model-epoch",
"--num-epochs",
type=int,
default=19,
help="If --use-ali-model is True, load "
"tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as "
"the alignment model."
"Used only if --use-ali-model is True.",
default=50,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
conformer_mmi/exp/epoch-{start_epoch-1}.pt
""",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
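The resume semantics of --start-epoch can be sketched as follows; maybe_resume is a hypothetical helper (the recipe itself resumes through icefall.checkpoint.load_checkpoint), and the checkpoint path follows the help text above:

from pathlib import Path

import torch

def maybe_resume(start_epoch: int, exp_dir: Path, model: torch.nn.Module) -> None:
    # A positive --start-epoch loads the checkpoint written at the end
    # of the previous epoch, i.e. epoch-{start_epoch-1}.pt.
    if start_epoch <= 0:
        return
    state = torch.load(exp_dir / f"epoch-{start_epoch - 1}.pt", map_location="cpu")
    model.load_state_dict(state["model"])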
@@ -90,7 +100,7 @@ def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
is saved in the variable `params`.
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
@@ -103,20 +113,6 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -135,36 +131,60 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- nhead: Number of heads of multi-head attention model.
- num_decoder_layers: Number of decoder layers of the transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- lr_factor: The lr_factor for Noam optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_mmi/exp_500"),
"lang_dir": Path("data/lang_bpe_500"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 10,
"use_pruned_intersect": False,
"den_scale": 1.0,
#
"att_rate": 0.7,
"valid_interval": 3000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"use_feat_batchnorm": True,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
"att_rate": 0.7,
# parameters for Noam
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
"use_pruned_intersect": False,
"den_scale": 1.0,
}
)
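lr_factor and warm_step feed the Noam optimizer imported from transformer; below is a hedged restatement of the classic Noam schedule, assuming the recipe's Noam class follows the standard form, with model_dim mirroring attention_dim:

def noam_lr(step: int, model_dim: int = 512,
            lr_factor: float = 5.0, warm_step: int = 80000) -> float:
    # Linear warm-up followed by inverse-square-root decay, as in
    # "Attention Is All You Need".
    step = max(step, 1)
    return lr_factor * model_dim ** -0.5 * min(
        step ** -0.5, step * warm_step ** -1.5
    )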
@@ -261,13 +281,12 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: nn.Module,
ali_model: Optional[nn.Module],
batch: dict,
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
is_training: bool,
):
"""
Compute MMI loss given the model and its inputs.
Compute LF-MMI loss given the model and its inputs.
Args:
params:
@@ -278,7 +297,9 @@ def compute_loss(
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build num_graphs and den_graphs.
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
@@ -286,54 +307,34 @@ def compute_loss(
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
if ali_model is not None and params.batch_idx_train < 4000:
feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T]
ali_model_output = ali_model(feature)
# subsampling is done slightly differently, may be small length
# differences.
min_len = min(ali_model_output.shape[1], nnet_output.shape[1])
# scale less than one so it will be encouraged
# to mimic ali_model's output
ali_model_scale = 500.0 / (params.batch_idx_train + 500)
# nnet_output is (N, T, C)
# Use clone() here or log-softmax backprop will fail.
nnet_output = nnet_output.clone()
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `LFMMILoss.forward()`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
nnet_output[:, :min_len, :] += (
ali_model_scale * ali_model_output[:, :min_len, :]
)
loss_fn = LFMMILoss(
graph_compiler=graph_compiler,
use_pruned_intersect=params.use_pruned_intersect,
den_scale=params.den_scale,
)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in LFMMILoss
#
# TODO: If params.use_pruned_intersect is True, there is no
# need to call encode_supervisions
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss_fn = LFMMILoss(
graph_compiler=graph_compiler,
den_scale=params.den_scale,
use_pruned_intersect=params.use_pruned_intersect,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
if params.att_rate != 0.0:
token_ids = graph_compiler.texts_to_ids(texts)
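To make the sorting requirement above concrete, here is a toy illustration (all values invented) of what encode_supervisions is expected to return:

import torch

# One int32 row per cut: [sequence_idx, start_frame, num_frames],
# sorted by decreasing num_frames so that k2.intersect_dense accepts
# the resulting DenseFsaVec; `texts` is reordered to match the rows.
supervision_segments = torch.tensor(
    [[1, 0, 120],
     [0, 0, 95],
     [2, 0, 40]],
    dtype=torch.int32,
)
texts = ["THE LONGEST CUT", "A SHORTER ONE", "YES"]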
@@ -373,8 +374,7 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
ali_model: Optional[nn.Module],
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> None:
@@ -391,7 +391,6 @@ def compute_validation_loss(
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
ali_model=ali_model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
@@ -432,9 +431,8 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
ali_model: Optional[nn.Module],
optimizer: torch.optim.Optimizer,
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
@@ -451,9 +449,6 @@ def train_one_epoch(
It is returned by :func:`get_params`.
model:
The model for training.
ali_model:
The force alignment model for training. It is from
tdnn_lstm_ctc/train_bpe.py
optimizer:
The optimizer we are using.
graph_compiler:
@@ -483,7 +478,6 @@ def train_one_epoch(
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
ali_model=ali_model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
@@ -494,7 +488,7 @@ def train_one_epoch(
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
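The clipping change above is purely cosmetic: torch.nn.utils.clip_grad_norm_ takes (parameters, max_norm, norm_type=2.0), so the keyword and positional spellings clip identically; a self-contained sketch with a placeholder model:

from torch import nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(4, 4)  # placeholder; the recipe clips the conformer
# Both calls clip the global L2 norm of the gradients to 5.0.
clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
clip_grad_norm_(model.parameters(), 5.0, 2.0)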
@@ -519,7 +513,7 @@ def train_one_epoch(
f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, "
f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
@@ -568,7 +562,6 @@ def train_one_epoch(
compute_validation_loss(
params=params,
model=model,
ali_model=ali_model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
@@ -576,10 +569,10 @@ def train_one_epoch(
model.train()
logging.info(
f"Epoch {params.cur_epoch}, "
f"valid mmi loss {params.valid_mmi_loss:.4f}, "
f"valid att loss {params.valid_att_loss:.4f}, "
f"valid loss {params.valid_loss:.4f}, "
f"best valid loss: {params.best_valid_loss:.4f}, "
f"valid mmi loss {params.valid_mmi_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
@@ -642,11 +635,13 @@ def run(rank, world_size, args):
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = BpeMmiTrainingGraphCompiler(
graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
oov="<UNK>",
sos_id=1,
eos_id=1,
)
logging.info("About to create model")
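A hedged sketch of why the compiler now carries sos_id/eos_id: with att_rate != 0 the attention decoder consumes token sequences wrapped in sos/eos, roughly as below (token ids invented; the recipe's own helpers may differ in detail):

sos_id, eos_id = 1, 1  # matches the arguments above
token_ids = [[5, 9, 2], [7, 3]]  # e.g. from graph_compiler.texts_to_ids

ys_in = [[sos_id] + ids for ids in token_ids]   # decoder inputs
ys_out = [ids + [eos_id] for ids in token_ids]  # decoder targets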
@@ -658,7 +653,6 @@ def run(rank, world_size, args):
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm,
)
@@ -679,32 +673,6 @@ def run(rank, world_size, args):
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
if args.use_ali_model:
ali_model = TdnnLstm(
num_features=params.feature_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
)
ali_model_fname = Path(
f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt"
)
assert (
ali_model_fname.is_file()
), f"ali model filename {ali_model_fname} does not exist!"
ali_model.load_state_dict(
torch.load(ali_model_fname, map_location="cpu")["model"]
)
ali_model.to(device)
ali_model.eval()
ali_model.requires_grad_(False)
logging.info(f"Use ali_model: {ali_model_fname}")
else:
ali_model = None
logging.info("No ali_model")
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
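For orientation, a hedged sketch of what the downstream loops pull out of these dataloaders, following the lhotse K2SpeechRecognitionDataset batch layout cited earlier:

for batch in train_dl:
    feature = batch["inputs"]             # (N, T, C) filter-bank features
    supervisions = batch["supervisions"]  # per-utterance cut/text metadata
    break  # one batch suffices for the illustration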
@@ -727,7 +695,6 @@ def run(rank, world_size, args):
train_one_epoch(
params=params,
model=model,
ali_model=ali_model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,

View File

@@ -227,5 +227,3 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
./local/compile_hlg.py --lang-dir $lang_dir
done
fi
cd data && ln -sfv lang_bpe_500 lang_bpe

View File

@@ -4,13 +4,13 @@ import k2
import torch
from torch import nn
from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
def _compute_mmi_loss_exact_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -98,7 +98,7 @@ def _compute_mmi_loss_exact_optimized(
def _compute_mmi_loss_exact_non_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -133,7 +133,7 @@ def _compute_mmi_loss_exact_non_optimized(
def _compute_mmi_loss_pruned(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -184,7 +184,7 @@ class LFMMILoss(nn.Module):
def __init__(
self,
graph_compiler: BpeMmiTrainingGraphCompiler,
graph_compiler: MmiTrainingGraphCompiler,
use_pruned_intersect: bool = False,
den_scale: float = 1.0,
):
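A hedged sketch of the dispatch this flag is assumed to drive inside LFMMILoss.forward; which exact variant serves as the default is an assumption:

def _pick_impl(use_pruned_intersect: bool):
    # Select the intersection strategy used for the MMI computation.
    if use_pruned_intersect:
        return _compute_mmi_loss_pruned
    return _compute_mmi_loss_exact_non_optimized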

View File

@@ -15,6 +15,8 @@ class MmiTrainingGraphCompiler(object):
uniq_filename: str = "uniq_lexicon.txt",
device: Union[str, torch.device] = "cpu",
oov: str = "<UNK>",
sos_id: int = 1,
eos_id: int = 1,
):
"""
Args:
@@ -45,6 +47,8 @@ class MmiTrainingGraphCompiler(object):
self.L_inv = self.lexicon.L_inv.to(self.device)
self.oov_id = self.lexicon.word_table[oov]
self.sos_id = sos_id
self.eos_id = eos_id
self.build_ctc_topo_P()
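Because the new keyword arguments default to 1, call sites written before this change stay source-compatible; a minimal stand-in (not the real compiler) demonstrating that:

class CompilerStub:
    def __init__(self, oov: str = "<UNK>", sos_id: int = 1, eos_id: int = 1):
        self.oov = oov
        self.sos_id = sos_id
        self.eos_id = eos_id

old_style = CompilerStub(oov="<UNK>")  # written before sos/eos existed
assert old_style.sos_id == 1 and old_style.eos_id == 1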
@@ -93,6 +97,7 @@ class MmiTrainingGraphCompiler(object):
).invert()
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")
def compile(
self, texts: Iterable[str], replicate_den: bool = True