mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Support using attention decoder in MMI training.
This commit is contained in:
parent
78e1fdc994
commit
4f3a53fc41
@ -85,7 +85,7 @@ def get_parser():
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
help="""Resume training from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
conformer_ctc/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
|
@ -57,7 +57,7 @@ class Conformer(Transformer):
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
is_espnet_structure: bool = False,
|
||||
mmi_loss: bool = True,
|
||||
is_bpe: bool = True,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
@ -72,7 +72,7 @@ class Conformer(Transformer):
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
mmi_loss=mmi_loss,
|
||||
is_bpe=is_bpe,
|
||||
use_feat_batchnorm=use_feat_batchnorm,
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
@ -59,6 +75,23 @@ def get_parser():
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
conformer_ctc/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||
and continue training from that checkpoint.
|
||||
|
||||
- num_epochs: Number of epochs to train.
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
@ -120,8 +148,6 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"weight_decay": 1e-6,
|
||||
"subsampling_factor": 4,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 50,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
@ -130,13 +156,14 @@ def get_params() -> AttributeDict:
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
"beam_size": 10,
|
||||
"use_pruned_intersect": False,
|
||||
"den_scale": 1.0,
|
||||
#
|
||||
"att_rate": 0, # If not zero, use attention decoder
|
||||
"att_rate": 0.7, # If not zero, use attention decoder
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"num_decoder_layers": 0,
|
||||
"num_decoder_layers": 6,
|
||||
"is_espnet_structure": True,
|
||||
"use_feat_batchnorm": True,
|
||||
"lr_factor": 5.0,
|
||||
@ -288,15 +315,14 @@ def compute_loss(
|
||||
|
||||
loss_fn = LFMMILoss(
|
||||
graph_compiler=graph_compiler,
|
||||
output_beam=params.beam_size,
|
||||
den_scale=params.den_scale,
|
||||
use_pruned_intersect=params.use_pruned_intersect,
|
||||
)
|
||||
|
||||
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
|
||||
|
||||
assert params.att_rate == 0
|
||||
if params.att_rate != 0.0:
|
||||
# TODO: not working
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
@ -304,16 +330,16 @@ def compute_loss(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
|
||||
else:
|
||||
@ -587,7 +613,6 @@ def run(rank, world_size, args):
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
@ -617,12 +642,19 @@ def run(rank, world_size, args):
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=False,
|
||||
is_bpe=False,
|
||||
is_espnet_structure=params.is_espnet_structure,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
assert model.decoder_num_class == num_classes + 1
|
||||
|
||||
params.sos_id = num_classes
|
||||
params.eos_id = num_classes
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
logging.info(params)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
@ -42,7 +42,7 @@ class Transformer(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
mmi_loss: bool = True,
|
||||
is_bpe: bool = True,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
@ -71,7 +71,9 @@ class Transformer(nn.Module):
|
||||
If True, use pre-layer norm; False to use post-layer norm.
|
||||
vgg_frontend:
|
||||
True to use vgg style frontend for subsampling.
|
||||
mmi_loss:
|
||||
is_bpe:
|
||||
True if the modeling unit is word pieces which has already included
|
||||
SOS/EOS IDs.
|
||||
use_feat_batchnorm:
|
||||
True to use batchnorm for the input layer.
|
||||
"""
|
||||
@ -123,7 +125,7 @@ class Transformer(nn.Module):
|
||||
)
|
||||
|
||||
if num_decoder_layers > 0:
|
||||
if mmi_loss:
|
||||
if is_bpe is False:
|
||||
self.decoder_num_class = (
|
||||
self.num_classes + 1
|
||||
) # +1 for the sos/eos symbol
|
||||
|
@ -11,6 +11,7 @@ def _compute_mmi_loss_exact_optimized(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
output_beam: float,
|
||||
den_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -79,7 +80,7 @@ def _compute_mmi_loss_exact_optimized(
|
||||
num_den_lats = k2.intersect_dense(
|
||||
num_den_reordered_graphs,
|
||||
dense_fsa_vec,
|
||||
output_beam=8.0,
|
||||
output_beam=output_beam,
|
||||
a_to_b_map=a_to_b_map,
|
||||
)
|
||||
|
||||
@ -99,6 +100,7 @@ def _compute_mmi_loss_exact_non_optimized(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
output_beam: float,
|
||||
den_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -112,9 +114,12 @@ def _compute_mmi_loss_exact_non_optimized(
|
||||
"""
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
|
||||
|
||||
# TODO: pass output_beam as function argument
|
||||
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
|
||||
den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=8.0)
|
||||
num_lats = k2.intersect_dense(
|
||||
num_graphs, dense_fsa_vec, output_beam=output_beam
|
||||
)
|
||||
den_lats = k2.intersect_dense(
|
||||
den_graphs, dense_fsa_vec, output_beam=output_beam
|
||||
)
|
||||
|
||||
num_tot_scores = num_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
@ -134,6 +139,7 @@ def _compute_mmi_loss_pruned(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
output_beam: float,
|
||||
den_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -148,7 +154,9 @@ def _compute_mmi_loss_pruned(
|
||||
"""
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
|
||||
|
||||
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
|
||||
num_lats = k2.intersect_dense(
|
||||
num_graphs, dense_fsa_vec, output_beam=output_beam
|
||||
)
|
||||
|
||||
# the values for search_beam/output_beam/min_active_states/max_active_states
|
||||
# are not tuned. You may want to tune them.
|
||||
@ -156,7 +164,7 @@ def _compute_mmi_loss_pruned(
|
||||
den_graphs,
|
||||
dense_fsa_vec,
|
||||
search_beam=20.0,
|
||||
output_beam=8.0,
|
||||
output_beam=output_beam,
|
||||
min_active_states=30,
|
||||
max_active_states=10000,
|
||||
)
|
||||
@ -185,13 +193,15 @@ class LFMMILoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
output_beam: float,
|
||||
use_pruned_intersect: bool = False,
|
||||
den_scale: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.graph_compiler = graph_compiler
|
||||
self.den_scale = den_scale
|
||||
self.output_beam = output_beam
|
||||
self.use_pruned_intersect = use_pruned_intersect
|
||||
self.den_scale = den_scale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -218,5 +228,6 @@ class LFMMILoss(nn.Module):
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
texts=texts,
|
||||
graph_compiler=self.graph_compiler,
|
||||
output_beam=self.output_beam,
|
||||
den_scale=self.den_scale,
|
||||
)
|
||||
|
@ -155,6 +155,12 @@ def mmi_graph_compiler_test():
|
||||
)
|
||||
den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
|
||||
|
||||
texts = ["cat at cat", "ac at ca"]
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
# c a t a t c a t a c a t SPN
|
||||
expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
|
||||
assert token_ids == expected_token_ids
|
||||
|
||||
|
||||
def test_main():
|
||||
generate_test_data()
|
||||
|
Loading…
x
Reference in New Issue
Block a user