Minor fixes.
parent 6f5d63492a
commit 9e6bd0f07c
@@ -99,7 +99,7 @@ def get_params() -> AttributeDict:
     """Return a dict containing training parameters.
 
     All training related parameters that are not passed from the commandline
-    is saved in the variable `params`.
+    are saved in the variable `params`.
 
     Commandline options are merged into `params` after they are parsed, so
     you can also access them via `params`.
@@ -1,4 +1,21 @@
 #!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                        Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import argparse
 import logging
@@ -11,21 +28,20 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
-from tdnn_lstm_ctc.model import TdnnLstm
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
 
-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.mmi import LFMMILoss
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     encode_supervisions,
@@ -61,28 +77,22 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--use-ali-model",
-        type=str2bool,
-        default=True,
-        help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py "
-        "and you have some checkpoints inside the directory "
-        "tdnn_lstm_ctc/exp_bpe_500 ."
-        "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
-        "as the pre-trained alignment model",
-    )
-    parser.add_argument(
-        "--ali-model-epoch",
+        "--num-epochs",
         type=int,
-        default=19,
-        help="If --use-ali-model is True, load "
-        "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as "
-        "the alignment model."
-        "Used only if --use-ali-model is True.",
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_mmi/exp/epoch-{start_epoch-1}.pt
+        """,
     )
 
-    # TODO: add extra arguments and support DDP training.
-    # Currently, only single GPU training is implemented. Will add
-    # DDP training once single GPU training is finished.
     return parser
 
 
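A note on the new options: with a positive `--start-epoch`, training resumes from the checkpoint of the previous epoch, as the help text above describes. The sketch below only illustrates that filename convention (`conformer_mmi/exp/epoch-{start_epoch-1}.pt`); the helper function and its name are made up for illustration, not part of this commit.

```python
# Illustration of the resume convention described by --start-epoch above.
# The helper below is hypothetical; only the filename pattern comes from the diff.
from pathlib import Path
from typing import Optional


def resume_checkpoint(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    """Return the checkpoint to load when resuming, or None for a fresh start."""
    if start_epoch <= 0:
        return None
    return exp_dir / f"epoch-{start_epoch - 1}.pt"


print(resume_checkpoint(Path("conformer_mmi/exp"), 10))  # conformer_mmi/exp/epoch-9.pt
```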
@@ -90,7 +100,7 @@ def get_params() -> AttributeDict:
     """Return a dict containing training parameters.
 
     All training related parameters that are not passed from the commandline
-    is saved in the variable `params`.
+    are saved in the variable `params`.
 
     Commandline options are merged into `params` after they are parsed, so
     you can also access them via `params`.
@@ -103,20 +113,6 @@ def get_params() -> AttributeDict:
         - lang_dir: It contains language related input files such as
                     "lexicon.txt"
 
-        - lr: It specifies the initial learning rate
-
-        - feature_dim: The model input dim. It has to match the one used
-                       in computing features.
-
-        - weight_decay: The weight_decay for the optimizer.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-                       and continue training from that checkpoint.
-
-        - num_epochs: Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -135,36 +131,60 @@ def get_params() -> AttributeDict:
 
         - log_interval: Print training loss if batch_idx % log_interval` is 0
 
-        - valid_interval: Run validation if batch_idx % valid_interval` is 0
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor: The subsampling factor for the model.
+
+        - use_feat_batchnorm: Whether to do batch normalization for the
+                              input features.
+
+        - attention_dim: Hidden dim for multi-head attention model.
+
+        - head: Number of heads of multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+        - weight_decay: The weight_decay for the optimizer.
+
+        - lr_factor: The lr_factor for Noam optimizer.
+
+        - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_mmi/exp_500"),
             "lang_dir": Path("data/lang_bpe_500"),
-            "feature_dim": 80,
-            "weight_decay": 1e-6,
-            "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 10,
+            "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 10,
-            "use_pruned_intersect": False,
-            "den_scale": 1.0,
-            #
-            "att_rate": 0.7,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
             "attention_dim": 512,
             "nhead": 8,
             "num_decoder_layers": 6,
-            "is_espnet_structure": True,
-            "use_feat_batchnorm": True,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            "att_rate": 0.7,
+            # parameters for Noam
+            "weight_decay": 1e-6,
             "lr_factor": 5.0,
             "warm_step": 80000,
+            "use_pruned_intersect": False,
+            "den_scale": 1.0,
         }
     )
 
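As the docstring above notes, command-line options are merged into `params` after parsing, so everything is reachable through one `AttributeDict`. Below is a minimal sketch of that pattern; the `params.update(vars(args))` line is the usual icefall idiom and is assumed here rather than shown in this hunk.

```python
# Minimal sketch: merge parsed command-line options into params.
# AttributeDict is a dict with attribute-style access, so params.num_epochs
# works after the update. Values below are only examples.
import argparse

from icefall.utils import AttributeDict

parser = argparse.ArgumentParser()
parser.add_argument("--num-epochs", type=int, default=50)
args = parser.parse_args(["--num-epochs", "20"])

params = AttributeDict({"log_interval": 50, "valid_interval": 3000})
params.update(vars(args))  # command-line options are merged in after parsing
print(params.num_epochs, params.valid_interval)  # 20 3000
```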
@@ -261,13 +281,12 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
-    ali_model: Optional[nn.Module],
     batch: dict,
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     is_training: bool,
 ):
     """
-    Compute MMI loss given the model and its inputs.
+    Compute LF-MMI loss given the model and its inputs.
 
     Args:
       params:
@@ -278,7 +297,9 @@ def compute_loss(
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
       graph_compiler:
-        It is used to build num_graphs and den_graphs.
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
       is_training:
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
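To make the updated docstring concrete: inside the loss, the compiler turns the batch transcripts into numerator/denominator graph pairs via its `compile()` method (shown near the end of this diff). The function below is only a hedged sketch; the `(num_graphs, den_graphs)` return convention is inferred from the old docstring wording, not stated in this hunk.

```python
# Hedged sketch of how MmiTrainingGraphCompiler is consumed by the LF-MMI loss.
# The (num_graphs, den_graphs) return convention is an assumption based on the
# old docstring ("build num_graphs and den_graphs").
from typing import List, Tuple

import k2

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler


def build_mmi_graphs(
    graph_compiler: MmiTrainingGraphCompiler, texts: List[str]
) -> Tuple[k2.Fsa, k2.Fsa]:
    """Build numerator/denominator graphs for one batch of transcripts."""
    return graph_compiler.compile(texts, replicate_den=True)
```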
@@ -286,54 +307,34 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
     feature = feature.to(device)
 
     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
-        if ali_model is not None and params.batch_idx_train < 4000:
-            feature = feature.permute(0, 2, 1)  # [N, T, C]->[N, C, T]
-            ali_model_output = ali_model(feature)
-            # subsampling is done slightly differently, may be small length
-            # differences.
-            min_len = min(ali_model_output.shape[1], nnet_output.shape[1])
-            # scale less than one so it will be encouraged
-            # to mimic ali_model's output
-            ali_model_scale = 500.0 / (params.batch_idx_train + 500)
-
-            # Use clone() here or log-softmax backprop will fail.
-            nnet_output = nnet_output.clone()
-
-            nnet_output[:, :min_len, :] += (
-                ali_model_scale * ali_model_output[:, :min_len, :]
-            )
+        # nnet_output is (N, T, C)
 
         # NOTE: We need `encode_supervisions` to sort sequences with
         # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in LFMMILoss
-        #
-        # TODO: If params.use_pruned_intersect is True, there is no
-        # need to call encode_supervisions
+        # `k2.intersect_dense` called in `LFMMILoss.forward()`
         supervision_segments, texts = encode_supervisions(
             supervisions, subsampling_factor=params.subsampling_factor
         )
 
-        dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
         loss_fn = LFMMILoss(
             graph_compiler=graph_compiler,
-            den_scale=params.den_scale,
             use_pruned_intersect=params.use_pruned_intersect,
+            den_scale=params.den_scale,
         )
 
-        mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+        mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
         if params.att_rate != 0.0:
             token_ids = graph_compiler.texts_to_ids(texts)
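The NOTE comment in the new code is the important constraint: `k2.intersect_dense`, called inside `LFMMILoss`, expects the supervision segments sorted by decreasing number of frames, which is what `encode_supervisions` produces. Below is a small sketch of that invariant with made-up durations; each row is `[sequence_idx, start_frame, num_frames]`, the layout consumed by `k2.DenseFsaVec`.

```python
# Sketch of the sorting invariant behind the NOTE above: supervision_segments
# rows are [sequence_idx, start_frame, num_frames] and must be ordered by
# decreasing num_frames before building k2.DenseFsaVec for intersect_dense.
# The durations below are made up.
import torch

supervision_segments = torch.tensor(
    [
        [1, 0, 1500],  # longest utterance first
        [0, 0, 1200],
        [2, 0, 750],
    ],
    dtype=torch.int32,
)

num_frames = supervision_segments[:, 2]
assert torch.all(num_frames[:-1] >= num_frames[1:]), "sort by decreasing duration"
```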
@@ -373,8 +374,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: nn.Module,
-    ali_model: Optional[nn.Module],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> None:
@@ -391,7 +391,6 @@ def compute_validation_loss(
         loss, mmi_loss, att_loss = compute_loss(
             params=params,
             model=model,
-            ali_model=ali_model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=False,
@@ -432,9 +431,8 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
-    ali_model: Optional[nn.Module],
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -451,9 +449,6 @@ def train_one_epoch(
        It is returned by :func:`get_params`.
      model:
        The model for training.
-    ali_model:
-      The force alignment model for training. It is from
-      tdnn_lstm_ctc/train_bpe.py
     optimizer:
       The optimizer we are using.
     graph_compiler:
@@ -483,7 +478,6 @@ def train_one_epoch(
         loss, mmi_loss, att_loss = compute_loss(
             params=params,
             model=model,
-            ali_model=ali_model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,
@@ -494,7 +488,7 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
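The gradient-clipping change above only drops the keyword names: `torch.nn.utils.clip_grad_norm_` takes `(parameters, max_norm, norm_type)` positionally, so both forms do the same thing. A quick check on a toy module:

```python
# The two clip_grad_norm_ calls in this hunk are equivalent; max_norm=5.0 and
# norm_type=2.0 are simply passed positionally in the new version.
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(4, 4)
model.weight.grad = torch.full_like(model.weight, 10.0)
model.bias.grad = torch.zeros_like(model.bias)

total_norm = clip_grad_norm_(model.parameters(), 5.0, 2.0)
print(float(total_norm))  # norm before clipping; grads are now scaled down to 5.0
```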
@@ -519,7 +513,7 @@ def train_one_epoch(
                 f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
                 f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
                 f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
-                f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, "
+                f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, "
                 f"total avg att loss: {tot_avg_att_loss:.4f}, "
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
@@ -568,7 +562,6 @@ def train_one_epoch(
             compute_validation_loss(
                 params=params,
                 model=model,
-                ali_model=ali_model,
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
@@ -576,10 +569,10 @@ def train_one_epoch(
             model.train()
             logging.info(
                 f"Epoch {params.cur_epoch}, "
-                f"valid mmi loss {params.valid_mmi_loss:.4f}, "
-                f"valid att loss {params.valid_att_loss:.4f}, "
-                f"valid loss {params.valid_loss:.4f}, "
-                f"best valid loss: {params.best_valid_loss:.4f}, "
+                f"valid mmi loss {params.valid_mmi_loss:.4f},"
+                f"valid att loss {params.valid_att_loss:.4f},"
+                f"valid loss {params.valid_loss:.4f},"
+                f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
             if tb_writer is not None:
@@ -642,11 +635,13 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    graph_compiler = BpeMmiTrainingGraphCompiler(
+    graph_compiler = MmiTrainingGraphCompiler(
         params.lang_dir,
+        uniq_filename="lexicon.txt",
         device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
+        oov="<UNK>",
+        sos_id=1,
+        eos_id=1,
     )
 
     logging.info("About to create model")
@@ -658,7 +653,6 @@ def run(rank, world_size, args):
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
-        is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
@@ -679,32 +673,6 @@ def run(rank, world_size, args):
     if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])
 
-    if args.use_ali_model:
-        ali_model = TdnnLstm(
-            num_features=params.feature_dim,
-            num_classes=num_classes,
-            subsampling_factor=params.subsampling_factor,
-        )
-
-        ali_model_fname = Path(
-            f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt"
-        )
-        assert (
-            ali_model_fname.is_file()
-        ), f"ali model filename {ali_model_fname} does not exist!"
-
-        ali_model.load_state_dict(
-            torch.load(ali_model_fname, map_location="cpu")["model"]
-        )
-        ali_model.to(device)
-
-        ali_model.eval()
-        ali_model.requires_grad_(False)
-        logging.info(f"Use ali_model: {ali_model_fname}")
-    else:
-        ali_model = None
-        logging.info("No ali_model")
-
     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
     valid_dl = librispeech.valid_dataloaders()
@@ -727,7 +695,6 @@ def run(rank, world_size, args):
         train_one_epoch(
             params=params,
             model=model,
-            ali_model=ali_model,
            optimizer=optimizer,
             graph_compiler=graph_compiler,
             train_dl=train_dl,
@@ -227,5 +227,3 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
-
-cd data && ln -sfv lang_bpe_500 lang_bpe
@@ -4,13 +4,13 @@ import k2
 import torch
 from torch import nn
 
-from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 
 
 def _compute_mmi_loss_exact_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -98,7 +98,7 @@ def _compute_mmi_loss_exact_optimized(
 def _compute_mmi_loss_exact_non_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -133,7 +133,7 @@ def _compute_mmi_loss_exact_non_optimized(
 def _compute_mmi_loss_pruned(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
-    graph_compiler: BpeMmiTrainingGraphCompiler,
+    graph_compiler: MmiTrainingGraphCompiler,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -184,7 +184,7 @@ class LFMMILoss(nn.Module):
 
     def __init__(
         self,
-        graph_compiler: BpeMmiTrainingGraphCompiler,
+        graph_compiler: MmiTrainingGraphCompiler,
         use_pruned_intersect: bool = False,
         den_scale: float = 1.0,
     ):
@@ -15,6 +15,8 @@ class MmiTrainingGraphCompiler(object):
         uniq_filename: str = "uniq_lexicon.txt",
         device: Union[str, torch.device] = "cpu",
         oov: str = "<UNK>",
+        sos_id: int = 1,
+        eos_id: int = 1,
     ):
         """
         Args:
@@ -45,6 +47,8 @@ class MmiTrainingGraphCompiler(object):
         self.L_inv = self.lexicon.L_inv.to(self.device)
 
         self.oov_id = self.lexicon.word_table[oov]
+        self.sos_id = sos_id
+        self.eos_id = eos_id
 
         self.build_ctc_topo_P()
 
@@ -93,6 +97,7 @@ class MmiTrainingGraphCompiler(object):
         ).invert()
 
         self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
+        logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")
 
     def compile(
         self, texts: Iterable[str], replicate_den: bool = True