Support using attention decoder in MMI training.

Fangjun Kuang 2021-09-11 14:16:45 +08:00
parent 78e1fdc994
commit 4f3a53fc41
6 changed files with 80 additions and 29 deletions

View File

@@ -85,7 +85,7 @@ def get_parser():
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
help="""Resume training from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",

View File

@@ -57,7 +57,7 @@ class Conformer(Transformer):
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
is_bpe: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
@@ -72,7 +72,7 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
is_bpe=is_bpe,
use_feat_batchnorm=use_feat_batchnorm,
)

View File

@@ -1,4 +1,20 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
@@ -59,6 +75,23 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=50,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
return parser
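
A minimal sketch (not part of the commit) of how the two new flags are consumed, assuming get_parser is imported from this training script; the values are illustrative:

    parser = get_parser()
    args = parser.parse_args(["--start-epoch", "10", "--num-epochs", "60"])
    assert args.start_epoch == 10 and args.num_epochs == 60
    # With a positive --start-epoch the script resumes from
    # conformer_ctc/exp/epoch-9.pt, per the help text above.
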
@@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@@ -120,8 +148,6 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@@ -130,13 +156,14 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
"beam_size": 10,
"use_pruned_intersect": False,
"den_scale": 1.0,
#
"att_rate": 0, # If not zero, use attention decoder
"att_rate": 0.7, # If not zero, use attention decoder
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 0,
"num_decoder_layers": 6,
"is_espnet_structure": True,
"use_feat_batchnorm": True,
"lr_factor": 5.0,
@@ -288,15 +315,14 @@ def compute_loss(
loss_fn = LFMMILoss(
graph_compiler=graph_compiler,
output_beam=params.beam_size,
den_scale=params.den_scale,
use_pruned_intersect=params.use_pruned_intersect,
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
assert params.att_rate == 0
if params.att_rate != 0.0:
# TODO: not working
token_ids = graph_compiler.texts_to_ids(texts)
with torch.set_grad_enabled(is_training):
if hasattr(model, "module"):
@@ -304,16 +330,16 @@
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
sos_id=params.sos_id,
eos_id=params.eos_id,
)
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
else:
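
With the new defaults (att_rate = 0.7), the total loss is a convex combination of the LF-MMI loss and the attention-decoder loss. A small numeric sketch with made-up loss values:

    att_rate = 0.7
    mmi_loss = 120.0   # hypothetical LF-MMI loss for one batch
    att_loss = 80.0    # hypothetical attention-decoder loss
    loss = (1.0 - att_rate) * mmi_loss + att_rate * att_loss
    assert abs(loss - 92.0) < 1e-6   # 0.3 * 120.0 + 0.7 * 80.0
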
@@ -587,7 +613,6 @@ def run(rank, world_size, args):
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
@@ -617,12 +642,19 @@
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
is_bpe=False,
is_espnet_structure=params.is_espnet_structure,
use_feat_batchnorm=params.use_feat_batchnorm,
)
assert model.decoder_num_class == num_classes + 1
params.sos_id = num_classes
params.eos_id = num_classes
checkpoints = load_checkpoint_if_available(params=params, model=model)
logging.info(params)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
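
The run() changes construct the model with is_bpe=False and then derive the SOS/EOS ID from the vocabulary size. A short sketch of that relationship, using a made-up num_classes:

    num_classes = 72                      # illustrative phone-vocabulary size
    decoder_num_class = num_classes + 1   # what the assert above checks
    sos_id = eos_id = num_classes         # the extra, last symbol
    assert 0 <= sos_id < decoder_num_class
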

View File

@@ -42,7 +42,7 @@ class Transformer(nn.Module):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
mmi_loss: bool = True,
is_bpe: bool = True,
use_feat_batchnorm: bool = False,
) -> None:
"""
@@ -71,7 +71,9 @@ class Transformer(nn.Module):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
mmi_loss:
is_bpe:
True if the modeling unit is word pieces, whose vocabulary already
includes the SOS/EOS IDs.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
@@ -123,7 +125,7 @@ class Transformer(nn.Module):
)
if num_decoder_layers > 0:
if mmi_loss:
if is_bpe is False:
self.decoder_num_class = (
self.num_classes + 1
) # +1 for the sos/eos symbol
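
The renamed flag only changes how the decoder output size is derived. A condensed sketch of that logic; the is_bpe=True branch is not shown in the hunk above and is assumed here:

    def decoder_num_class(num_classes: int, is_bpe: bool) -> int:
        if is_bpe is False:
            # Phone-like units: reserve one extra ID for the SOS/EOS symbol.
            return num_classes + 1
        # Word pieces: assumed to already contain the SOS/EOS ID.
        return num_classes

    assert decoder_num_class(72, is_bpe=False) == 73
    assert decoder_num_class(500, is_bpe=True) == 500
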

View File

@@ -11,6 +11,7 @@ def _compute_mmi_loss_exact_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -79,7 +80,7 @@ def _compute_mmi_loss_exact_optimized(
num_den_lats = k2.intersect_dense(
num_den_reordered_graphs,
dense_fsa_vec,
output_beam=8.0,
output_beam=output_beam,
a_to_b_map=a_to_b_map,
)
@@ -99,6 +100,7 @@ def _compute_mmi_loss_exact_non_optimized(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -112,9 +114,12 @@
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
# TODO: pass output_beam as function argument
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=8.0)
num_lats = k2.intersect_dense(
num_graphs, dense_fsa_vec, output_beam=output_beam
)
den_lats = k2.intersect_dense(
den_graphs, dense_fsa_vec, output_beam=output_beam
)
num_tot_scores = num_lats.get_tot_scores(
log_semiring=True, use_double_scores=True
@@ -134,6 +139,7 @@ def _compute_mmi_loss_pruned(
dense_fsa_vec: k2.DenseFsaVec,
texts: List[str],
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
den_scale: float = 1.0,
) -> torch.Tensor:
"""
@@ -148,7 +154,9 @@
"""
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
num_lats = k2.intersect_dense(
num_graphs, dense_fsa_vec, output_beam=output_beam
)
# the values for search_beam/output_beam/min_active_states/max_active_states
# are not tuned. You may want to tune them.
@@ -156,7 +164,7 @@
den_graphs,
dense_fsa_vec,
search_beam=20.0,
output_beam=8.0,
output_beam=output_beam,
min_active_states=30,
max_active_states=10000,
)
@@ -185,13 +193,15 @@ class LFMMILoss(nn.Module):
def __init__(
self,
graph_compiler: MmiTrainingGraphCompiler,
output_beam: float,
use_pruned_intersect: bool = False,
den_scale: float = 1.0,
):
super().__init__()
self.graph_compiler = graph_compiler
self.den_scale = den_scale
self.output_beam = output_beam
self.use_pruned_intersect = use_pruned_intersect
self.den_scale = den_scale
def forward(
self,
@@ -218,5 +228,6 @@ class LFMMILoss(nn.Module):
dense_fsa_vec=dense_fsa_vec,
texts=texts,
graph_compiler=self.graph_compiler,
output_beam=self.output_beam,
den_scale=self.den_scale,
)
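
After this change the caller decides the intersection beam. A hedged sketch of constructing the loss the way the updated train.py does; the import path is an assumption, and graph_compiler must be a real MmiTrainingGraphCompiler:

    from icefall.mmi import LFMMILoss  # assumed import path

    def build_mmi_loss(graph_compiler, beam_size: float = 10.0) -> LFMMILoss:
        # Mirrors the updated call in compute_loss() above.
        return LFMMILoss(
            graph_compiler=graph_compiler,
            output_beam=beam_size,           # forwarded to the intersections
            use_pruned_intersect=False,
            den_scale=1.0,
        )
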

View File

@@ -155,6 +155,12 @@ def mmi_graph_compiler_test():
)
den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
texts = ["cat at cat", "ac at ca"]
token_ids = graph_compiler.texts_to_ids(texts)
# c a t a t c a t a c a t SPN
expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
assert token_ids == expected_token_ids
def test_main():
generate_test_data()