Support using attention decoder in MMI training.

Fangjun Kuang 2021-09-11 14:16:45 +08:00
parent 78e1fdc994
commit 4f3a53fc41
6 changed files with 80 additions and 29 deletions

View File

@@ -85,7 +85,7 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=0,
-        help="""Resume training from from this epoch.
+        help="""Resume training from this epoch.
         If it is positive, it will load checkpoint from
         conformer_ctc/exp/epoch-{start_epoch-1}.pt
         """,

View File

@@ -57,7 +57,7 @@ class Conformer(Transformer):
         normalize_before: bool = True,
         vgg_frontend: bool = False,
         is_espnet_structure: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
@@ -72,7 +72,7 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            mmi_loss=mmi_loss,
+            is_bpe=is_bpe,
             use_feat_batchnorm=use_feat_batchnorm,
         )

View File

@@ -1,4 +1,20 @@
 #!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import logging
@@ -59,6 +75,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
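These two options replace the hard-coded "start_epoch" and "num_epochs" entries removed from get_params() below, so both values can be set from the command line. A minimal sketch of the resume convention described in the help text, assuming the conformer_ctc/exp/epoch-N.pt layout; resolve_resume_checkpoint is an illustrative helper, not part of the recipe:

from pathlib import Path
from typing import Optional


def resolve_resume_checkpoint(exp_dir: str, start_epoch: int) -> Optional[Path]:
    # A start_epoch of 0 means "train from scratch"; a positive value
    # resumes from the checkpoint saved at the end of the previous epoch.
    if start_epoch <= 0:
        return None
    return Path(exp_dir) / f"epoch-{start_epoch - 1}.pt"


# e.g. --start-epoch 10 would load conformer_ctc/exp/epoch-9.pt
assert resolve_resume_checkpoint("conformer_ctc/exp", 10) == Path(
    "conformer_ctc/exp/epoch-9.pt"
)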
@@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor:  The subsampling factor for the model.
 
-        - start_epoch:  If it is not zero, load checkpoint `start_epoch-1`
-                        and continue training from that checkpoint.
-
-        - num_epochs:  Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.
@@ -120,8 +148,6 @@
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
@@ -130,13 +156,14 @@
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
+            "beam_size": 10,
             "use_pruned_intersect": False,
             "den_scale": 1.0,
             #
-            "att_rate": 0,  # If not zero, use attention decoder
+            "att_rate": 0.7,  # If not zero, use attention decoder
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 0,
+            "num_decoder_layers": 6,
             "is_espnet_structure": True,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,
@@ -288,15 +315,14 @@ def compute_loss(
 
         loss_fn = LFMMILoss(
             graph_compiler=graph_compiler,
+            output_beam=params.beam_size,
             den_scale=params.den_scale,
             use_pruned_intersect=params.use_pruned_intersect,
         )
 
         mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
-    assert params.att_rate == 0
     if params.att_rate != 0.0:
-        # TODO: not working
         token_ids = graph_compiler.texts_to_ids(texts)
         with torch.set_grad_enabled(is_training):
             if hasattr(model, "module"):
@@ -304,16 +330,16 @@
                     encoder_memory,
                     memory_mask,
                     token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
+                    sos_id=params.sos_id,
+                    eos_id=params.eos_id,
                 )
             else:
                 att_loss = model.decoder_forward(
                     encoder_memory,
                     memory_mask,
                     token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
+                    sos_id=params.sos_id,
+                    eos_id=params.eos_id,
                 )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
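Dropping the assert and the TODO makes the attention branch live. The hasattr(model, "module") test is the usual way to reach a custom method on a model that may or may not be wrapped in DistributedDataParallel; a minimal self-contained sketch of that dispatch pattern (ToyModel and call_decoder_forward are illustrative names, not from the recipe):

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def decoder_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real attention-decoder forward pass.
        return x.sum()


def call_decoder_forward(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # DDP wraps the user's module and exposes it as `model.module`,
    # so custom methods must be reached through that attribute.
    if hasattr(model, "module"):
        return model.module.decoder_forward(x)
    return model.decoder_forward(x)


x = torch.ones(3)
print(call_decoder_forward(ToyModel(), x))  # works on the bare model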
@@ -587,7 +613,6 @@ def run(rank, world_size, args):
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
 
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
@@ -617,12 +642,19 @@
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
+        is_bpe=False,
         is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
 
+    assert model.decoder_num_class == num_classes + 1
+    params.sos_id = num_classes
+    params.eos_id = num_classes
+
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
+    logging.info(params)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])
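Because the model is built with is_bpe=False, its decoder vocabulary has num_classes + 1 entries and the extra last index serves as both SOS and EOS, which is what the assert and the params.sos_id/params.eos_id assignments encode. A sketch of how such IDs are typically consumed when preparing decoder inputs and targets (add_sos/add_eos are illustrative helpers, not necessarily the recipe's own):

from typing import List


def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    # Prepend the SOS symbol to every utterance (decoder input).
    return [[sos_id] + ids for ids in token_ids]


def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    # Append the EOS symbol to every utterance (decoder target).
    return [ids + [eos_id] for ids in token_ids]


num_classes = 4                 # hypothetical phone-level vocabulary size
sos_id = eos_id = num_classes   # the single extra symbol, as in run()
print(add_sos([[3, 2, 4]], sos_id))  # [[4, 3, 2, 4]]
print(add_eos([[3, 2, 4]], eos_id))  # [[3, 2, 4, 4]]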

View File

@@ -42,7 +42,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         """
@@ -71,7 +71,9 @@ class Transformer(nn.Module):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          mmi_loss:
+          is_bpe:
+            True if the modeling units are word pieces, which have already
+            included the SOS/EOS IDs.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
         """
@@ -123,7 +125,7 @@
         )
 
         if num_decoder_layers > 0:
-            if mmi_loss:
+            if is_bpe is False:
                 self.decoder_num_class = (
                     self.num_classes + 1
                 )  # +1 for the sos/eos symbol
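The rename makes the condition self-explanatory: the extra output symbol is only appended when the modeling units do not already contain SOS/EOS. A sketch of the branching, where the else case (not visible in this hunk) presumably keeps decoder_num_class equal to num_classes for BPE units:

def get_decoder_num_class(num_classes: int, is_bpe: bool) -> int:
    # Illustrative only; mirrors the `if is_bpe is False` branch above.
    if is_bpe is False:
        # Phone/word units: append one extra symbol to serve as SOS/EOS.
        return num_classes + 1
    # BPE units are assumed to already reserve SOS/EOS IDs in the vocabulary.
    return num_classes


print(get_decoder_num_class(500, is_bpe=False))  # 501
print(get_decoder_num_class(500, is_bpe=True))   # 500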

View File

@@ -11,6 +11,7 @@ def _compute_mmi_loss_exact_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -79,7 +80,7 @@ def _compute_mmi_loss_exact_optimized(
     num_den_lats = k2.intersect_dense(
         num_den_reordered_graphs,
         dense_fsa_vec,
-        output_beam=8.0,
+        output_beam=output_beam,
         a_to_b_map=a_to_b_map,
     )
@@ -99,6 +100,7 @@ def _compute_mmi_loss_exact_non_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -112,9 +114,12 @@
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
-    # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=8.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     num_tot_scores = num_lats.get_tot_scores(
         log_semiring=True, use_double_scores=True
@@ -134,6 +139,7 @@ def _compute_mmi_loss_pruned(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """
@@ -148,7 +154,9 @@
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.
@@ -156,7 +164,7 @@
         den_graphs,
         dense_fsa_vec,
         search_beam=20.0,
-        output_beam=8.0,
+        output_beam=output_beam,
         min_active_states=30,
         max_active_states=10000,
     )
@@ -185,13 +193,15 @@ class LFMMILoss(nn.Module):
     def __init__(
         self,
         graph_compiler: MmiTrainingGraphCompiler,
+        output_beam: float,
         use_pruned_intersect: bool = False,
         den_scale: float = 1.0,
     ):
         super().__init__()
         self.graph_compiler = graph_compiler
-        self.den_scale = den_scale
+        self.output_beam = output_beam
         self.use_pruned_intersect = use_pruned_intersect
+        self.den_scale = den_scale
 
     def forward(
         self,
@@ -218,5 +228,6 @@
             dense_fsa_vec=dense_fsa_vec,
             texts=texts,
             graph_compiler=self.graph_compiler,
+            output_beam=self.output_beam,
             den_scale=self.den_scale,
         )
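Threading output_beam through LFMMILoss and the three _compute_mmi_loss_* helpers removes the hard-coded 8.0/10.0 beams and matches the new construction in compute_loss. A usage fragment mirroring that call site; it is not standalone-runnable, since graph_compiler, dense_fsa_vec and texts are assumed to exist, built from MmiTrainingGraphCompiler and k2.DenseFsaVec as in the recipe:

# Fragment: assumes `graph_compiler`, `dense_fsa_vec` and `texts`
# are already built as in compute_loss().
loss_fn = LFMMILoss(
    graph_compiler=graph_compiler,
    output_beam=10.0,            # params.beam_size in the recipe
    use_pruned_intersect=False,  # params.use_pruned_intersect
    den_scale=1.0,               # params.den_scale
)
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)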

View File

@@ -155,6 +155,12 @@ def mmi_graph_compiler_test():
     )
     den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
 
+    texts = ["cat at cat", "ac at ca"]
+    token_ids = graph_compiler.texts_to_ids(texts)
+    #                      c  a  t  a  t  c  a  t     a  c  a  t  SPN
+    expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
+    assert token_ids == expected_token_ids
+
 
 def test_main():
     generate_test_data()
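The new assertion documents the toy lexicon used by the test: tokens a=2, c=3, t=4 and SPN=1, with the unknown word "ca" collapsing to SPN. A plain-Python mimic of graph_compiler.texts_to_ids() for illustration only; the word-to-token table is inferred from the expected output, not read from the test's lexicon file:

from typing import Dict, List

token2id: Dict[str, int] = {"SPN": 1, "a": 2, "c": 3, "t": 4}
word2tokens: Dict[str, List[str]] = {
    "cat": ["c", "a", "t"],
    "at": ["a", "t"],
    "ac": ["a", "c"],
}


def texts_to_ids(texts: List[str]) -> List[List[int]]:
    ids: List[List[int]] = []
    for text in texts:
        utt: List[int] = []
        for word in text.split():
            # Words outside the toy lexicon map to the single SPN token.
            tokens = word2tokens.get(word, ["SPN"])
            utt.extend(token2id[t] for t in tokens)
        ids.append(utt)
    return ids


assert texts_to_ids(["cat at cat", "ac at ca"]) == [
    [3, 2, 4, 2, 4, 3, 2, 4],
    [2, 3, 2, 4, 1],
]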