Minor fixes.

Fangjun Kuang 2021-09-24 20:34:34 +08:00
parent 6f5d63492a
commit 9e6bd0f07c
5 changed files with 119 additions and 149 deletions

View File

@@ -99,7 +99,7 @@ def get_params() -> AttributeDict:
    """Return a dict containing training parameters.
    All training related parameters that are not passed from the commandline
-    is saved in the variable `params`.
+    are saved in the variable `params`.
    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.
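The docstring above describes how parsed command-line options end up in `params`. As a rough illustration only (not part of this commit), and assuming `AttributeDict` is the dict-with-attribute-access helper from `icefall.utils`, the merge typically looks like:

    import argparse
    from icefall.utils import AttributeDict

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-epochs", type=int, default=50)
    args = parser.parse_args()

    params = AttributeDict({"lang_dir": "data/lang_bpe_500"})
    params.update(vars(args))  # command-line options are merged into params
    assert params.num_epochs == args.num_epochs  # now accessible via params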

View File

@@ -1,4 +1,21 @@
#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

import argparse
import logging
@@ -11,21 +28,20 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
-from tdnn_lstm_ctc.model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam

-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.mmi import LFMMILoss
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
    AttributeDict,
    encode_supervisions,
@@ -61,28 +77,22 @@ def get_parser():
    )
    parser.add_argument(
-        "--use-ali-model",
-        type=str2bool,
-        default=True,
-        help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py "
-        "and you have some checkpoints inside the directory "
-        "tdnn_lstm_ctc/exp_bpe_500 ."
-        "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
-        "as the pre-trained alignment model",
-    )
-    parser.add_argument(
-        "--ali-model-epoch",
+        "--num-epochs",
        type=int,
-        default=19,
-        help="If --use-ali-model is True, load "
-        "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as "
-        "the alignment model."
-        "Used only if --use-ali-model is True.",
+        default=50,
+        help="Number of epochs to train.",
+    )
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_mmi/exp/epoch-{start_epoch-1}.pt
+        """,
    )
-    # TODO: add extra arguments and support DDP training.
-    # Currently, only single GPU training is implemented. Will add
-    # DDP training once single GPU training is finished.
    return parser
@@ -90,7 +100,7 @@ def get_params() -> AttributeDict:
    """Return a dict containing training parameters.
    All training related parameters that are not passed from the commandline
-    is saved in the variable `params`.
+    are saved in the variable `params`.
    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.
@@ -103,20 +113,6 @@ def get_params() -> AttributeDict:
        - lang_dir: It contains language related input files such as
          "lexicon.txt"
-        - lr: It specifies the initial learning rate
-        - feature_dim: The model input dim. It has to match the one used
-          in computing features.
-        - weight_decay: The weight_decay for the optimizer.
-        - subsampling_factor: The subsampling factor for the model.
-        - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-          and continue training from that checkpoint.
-        - num_epochs: Number of epochs to train.
        - best_train_loss: Best training loss so far. It is used to select
          the model that has the lowest training loss. It is
          updated during the training.
@@ -135,36 +131,60 @@ def get_params() -> AttributeDict:
        - log_interval: Print training loss if batch_idx % log_interval` is 0
-        - valid_interval: Run validation if batch_idx % valid_interval` is 0
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+        - feature_dim: The model input dim. It has to match the one used
+          in computing features.
+        - subsampling_factor: The subsampling factor for the model.
+        - use_feat_batchnorm: Whether to do batch normalization for the
+          input features.
+        - attention_dim: Hidden dim for multi-head attention model.
+        - head: Number of heads of multi-head attention model.
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+        - weight_decay: The weight_decay for the optimizer.
+        - lr_factor: The lr_factor for Noam optimizer.
+        - warm_step: The warm_step for Noam optimizer.
    """
    params = AttributeDict(
        {
            "exp_dir": Path("conformer_mmi/exp_500"),
            "lang_dir": Path("data/lang_bpe_500"),
-            "feature_dim": 80,
-            "weight_decay": 1e-6,
-            "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 50,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
-            "log_interval": 10,
+            "log_interval": 50,
            "reset_interval": 200,
-            "valid_interval": 10,
-            "use_pruned_intersect": False,
-            "den_scale": 1.0,
-            #
-            "att_rate": 0.7,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
            "attention_dim": 512,
            "nhead": 8,
            "num_decoder_layers": 6,
-            "is_espnet_structure": True,
-            "use_feat_batchnorm": True,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "att_rate": 0.7,
+            # parameters for Noam
+            "weight_decay": 1e-6,
            "lr_factor": 5.0,
            "warm_step": 80000,
+            "use_pruned_intersect": False,
+            "den_scale": 1.0,
        }
    )
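The regrouped entries above separate conformer, loss, and Noam-related settings. For orientation only, a hedged sketch of how the Noam entries are typically consumed; the exact constructor signature of `Noam` (imported from transformer.py above) is an assumption, not shown in this diff:

    # Hypothetical usage; the Noam signature below is assumed, not part of the diff.
    optimizer = Noam(
        model.parameters(),
        model_size=params.attention_dim,
        factor=params.lr_factor,
        warm_step=params.warm_step,
        weight_decay=params.weight_decay,
    )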
@@ -261,13 +281,12 @@ def save_checkpoint(
def compute_loss(
    params: AttributeDict,
    model: nn.Module,
-    ali_model: Optional[nn.Module],
    batch: dict,
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    is_training: bool,
):
    """
-    Compute MMI loss given the model and its inputs.
+    Compute LF-MMI loss given the model and its inputs.
    Args:
      params:
@@ -278,7 +297,9 @@ def compute_loss(
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      graph_compiler:
-        It is used to build num_graphs and den_graphs.
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
@@ -286,53 +307,33 @@
    """
    device = graph_compiler.device
    feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
    feature = feature.to(device)
    supervisions = batch["supervisions"]
    with torch.set_grad_enabled(is_training):
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
-        if ali_model is not None and params.batch_idx_train < 4000:
-            feature = feature.permute(0, 2, 1)  # [N, T, C]->[N, C, T]
-            ali_model_output = ali_model(feature)
-            # subsampling is done slightly differently, may be small length
-            # differences.
-            min_len = min(ali_model_output.shape[1], nnet_output.shape[1])
-            # scale less than one so it will be encouraged
-            # to mimic ali_model's output
-            ali_model_scale = 500.0 / (params.batch_idx_train + 500)
-            # Use clone() here or log-softmax backprop will fail.
-            nnet_output = nnet_output.clone()
-            nnet_output[:, :min_len, :] += (
-                ali_model_scale * ali_model_output[:, :min_len, :]
-            )
+        # nnet_output is (N, T, C)
        # NOTE: We need `encode_supervisions` to sort sequences with
        # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in LFMMILoss
+        # `k2.intersect_dense` called in `LFMMILoss.forward()`
+        #
+        # TODO: If params.use_pruned_intersect is True, there is no
+        # need to call encode_supervisions
        supervision_segments, texts = encode_supervisions(
            supervisions, subsampling_factor=params.subsampling_factor
        )
-        loss_fn = LFMMILoss(
-            graph_compiler=graph_compiler,
-            use_pruned_intersect=params.use_pruned_intersect,
-            den_scale=params.den_scale,
-        )
        dense_fsa_vec = k2.DenseFsaVec(
            nnet_output,
            supervision_segments,
            allow_truncate=params.subsampling_factor - 1,
        )
+        loss_fn = LFMMILoss(
+            graph_compiler=graph_compiler,
+            den_scale=params.den_scale,
+            use_pruned_intersect=params.use_pruned_intersect,
+        )
        mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
    if params.att_rate != 0.0:
@@ -373,8 +374,7 @@ def compute_loss(
def compute_validation_loss(
    params: AttributeDict,
    model: nn.Module,
-    ali_model: Optional[nn.Module],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> None:
@@ -391,7 +391,6 @@ def compute_validation_loss(
        loss, mmi_loss, att_loss = compute_loss(
            params=params,
            model=model,
-            ali_model=ali_model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=False,
@@ -432,9 +431,8 @@ def compute_validation_loss(
def train_one_epoch(
    params: AttributeDict,
    model: nn.Module,
-    ali_model: Optional[nn.Module],
    optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    tb_writer: Optional[SummaryWriter] = None,
@@ -451,9 +449,6 @@ def train_one_epoch(
        It is returned by :func:`get_params`.
      model:
        The model for training.
-      ali_model:
-        The force alignment model for training. It is from
-        tdnn_lstm_ctc/train_bpe.py
      optimizer:
        The optimizer we are using.
      graph_compiler:
@@ -483,7 +478,6 @@ def train_one_epoch(
        loss, mmi_loss, att_loss = compute_loss(
            params=params,
            model=model,
-            ali_model=ali_model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
@@ -494,7 +488,7 @@
        optimizer.zero_grad()
        loss.backward()
-        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
        optimizer.step()
        loss_cpu = loss.detach().cpu().item()
@@ -568,7 +562,6 @@ def train_one_epoch(
            compute_validation_loss(
                params=params,
                model=model,
-                ali_model=ali_model,
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
@@ -579,7 +572,7 @@
                f"valid mmi loss {params.valid_mmi_loss:.4f},"
                f"valid att loss {params.valid_att_loss:.4f},"
                f"valid loss {params.valid_loss:.4f},"
-                f"best valid loss: {params.best_valid_loss:.4f}, "
+                f" best valid loss: {params.best_valid_loss:.4f} "
                f"best valid epoch: {params.best_valid_epoch}"
            )
            if tb_writer is not None:
@@ -642,11 +635,13 @@ def run(rank, world_size, args):
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
-    graph_compiler = BpeMmiTrainingGraphCompiler(
+    graph_compiler = MmiTrainingGraphCompiler(
        params.lang_dir,
+        uniq_filename="lexicon.txt",
        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        oov="<UNK>",
+        sos_id=1,
+        eos_id=1,
    )
    logging.info("About to create model")
@@ -658,7 +653,6 @@ def run(rank, world_size, args):
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=False,
-        is_espnet_structure=params.is_espnet_structure,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )
@@ -679,32 +673,6 @@ def run(rank, world_size, args):
    if checkpoints:
        optimizer.load_state_dict(checkpoints["optimizer"])
-    if args.use_ali_model:
-        ali_model = TdnnLstm(
-            num_features=params.feature_dim,
-            num_classes=num_classes,
-            subsampling_factor=params.subsampling_factor,
-        )
-        ali_model_fname = Path(
-            f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt"
-        )
-        assert (
-            ali_model_fname.is_file()
-        ), f"ali model filename {ali_model_fname} does not exist!"
-        ali_model.load_state_dict(
-            torch.load(ali_model_fname, map_location="cpu")["model"]
-        )
-        ali_model.to(device)
-        ali_model.eval()
-        ali_model.requires_grad_(False)
-        logging.info(f"Use ali_model: {ali_model_fname}")
-    else:
-        ali_model = None
-        logging.info("No ali_model")
    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()
@@ -727,7 +695,6 @@ def run(rank, world_size, args):
        train_one_epoch(
            params=params,
            model=model,
-            ali_model=ali_model,
            optimizer=optimizer,
            graph_compiler=graph_compiler,
            train_dl=train_dl,

View File

@@ -227,5 +227,3 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
    ./local/compile_hlg.py --lang-dir $lang_dir
  done
fi
-cd data && ln -sfv lang_bpe_500 lang_bpe

View File

@@ -4,13 +4,13 @@ import k2
import torch
from torch import nn

-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler


def _compute_mmi_loss_exact_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
) -> torch.Tensor:
    """
@@ -98,7 +98,7 @@ def _compute_mmi_loss_exact_optimized(
def _compute_mmi_loss_exact_non_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
) -> torch.Tensor:
    """
@@ -133,7 +133,7 @@ def _compute_mmi_loss_exact_non_optimized(
def _compute_mmi_loss_pruned(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
) -> torch.Tensor:
    """
@@ -184,7 +184,7 @@ class LFMMILoss(nn.Module):
    def __init__(
        self,
-        graph_compiler: BpeMmiTrainingGraphCompiler,
+        graph_compiler: MmiTrainingGraphCompiler,
        use_pruned_intersect: bool = False,
        den_scale: float = 1.0,
    ):
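For orientation, the way this class is driven after the rename is already visible in the train.py hunk above; restated here as a small sketch (values illustrative, identifiers taken from this commit):

    # Sketch of typical usage, mirroring the train.py change in this commit.
    loss_fn = LFMMILoss(
        graph_compiler=graph_compiler,  # an MmiTrainingGraphCompiler instance
        den_scale=1.0,
        use_pruned_intersect=False,  # exact (non-pruned) intersection path
    )
    mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)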

View File

@@ -15,6 +15,8 @@ class MmiTrainingGraphCompiler(object):
        uniq_filename: str = "uniq_lexicon.txt",
        device: Union[str, torch.device] = "cpu",
        oov: str = "<UNK>",
+        sos_id: int = 1,
+        eos_id: int = 1,
    ):
        """
        Args:
@@ -45,6 +47,8 @@
        self.L_inv = self.lexicon.L_inv.to(self.device)
        self.oov_id = self.lexicon.word_table[oov]
+        self.sos_id = sos_id
+        self.eos_id = eos_id
        self.build_ctc_topo_P()
@@ -93,6 +97,7 @@
        ).invert()
        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
+        logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")

    def compile(
        self, texts: Iterable[str], replicate_den: bool = True
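The new sos_id/eos_id arguments are exercised by the train.py change earlier in this commit; for reference, a minimal instantiation sketch with values taken from that hunk (the lang dir path and device here are illustrative only):

    graph_compiler = MmiTrainingGraphCompiler(
        "data/lang_bpe_500",          # lang_dir containing the lexicon files
        uniq_filename="lexicon.txt",  # as in the train.py hunk above
        device="cpu",                 # train.py passes a CUDA device instead
        oov="<UNK>",
        sos_id=1,
        eos_id=1,
    )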