mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)

commit 4f3a53fc41 (parent 78e1fdc994)

    Support using attention decoder in MMI training.
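This commit lets the MMI recipe interpolate the LF-MMI loss with an attention-decoder loss: the attention decoder is on by default (att_rate 0.7, 6 decoder layers), SOS/EOS ids are taken from params instead of the graph compiler, and the k2 pruning beam becomes a configurable beam_size. A minimal, self-contained sketch of the loss interpolation performed in compute_loss below; the loss values here are made up purely for illustration:

import torch

att_rate = 0.7  # default introduced by this commit, see get_params()

# Hypothetical loss values, just to show the weighting.
mmi_loss = torch.tensor(120.0)  # LF-MMI loss from LFMMILoss
att_loss = torch.tensor(95.0)   # cross-entropy loss from the attention decoder

if att_rate != 0.0:
    loss = (1.0 - att_rate) * mmi_loss + att_rate * att_loss
else:
    loss = mmi_loss

print(loss)  # 0.3 * 120 + 0.7 * 95 = 102.5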
@@ -85,7 +85,7 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=0,
-        help="""Resume training from from this epoch.
+        help="""Resume training from this epoch.
         If it is positive, it will load checkpoint from
         conformer_ctc/exp/epoch-{start_epoch-1}.pt
         """,

@@ -57,7 +57,7 @@ class Conformer(Transformer):
         normalize_before: bool = True,
         vgg_frontend: bool = False,
         is_espnet_structure: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(

@@ -72,7 +72,7 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            mmi_loss=mmi_loss,
+            is_bpe=is_bpe,
             use_feat_batchnorm=use_feat_batchnorm,
         )
 
@@ -1,4 +1,20 @@
 #!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import argparse
 import logging

@@ -59,6 +75,23 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=50,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=0,
+        help="""Resume training from this epoch.
+        If it is positive, it will load checkpoint from
+        conformer_ctc/exp/epoch-{start_epoch-1}.pt
+        """,
+    )
+
     return parser
 
 
@@ -88,11 +121,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
-                       and continue training from that checkpoint.
-
-        - num_epochs: Number of epochs to train.
-
         - best_train_loss: Best training loss so far. It is used to select
                            the model that has the lowest training loss. It is
                            updated during the training.

@@ -120,8 +148,6 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
-            "start_epoch": 0,
-            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,

@@ -130,13 +156,14 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
+            "beam_size": 10,
             "use_pruned_intersect": False,
             "den_scale": 1.0,
             #
-            "att_rate": 0,  # If not zero, use attention decoder
+            "att_rate": 0.7,  # If not zero, use attention decoder
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 0,
+            "num_decoder_layers": 6,
             "is_espnet_structure": True,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,

@@ -288,15 +315,14 @@ def compute_loss(
 
     loss_fn = LFMMILoss(
         graph_compiler=graph_compiler,
+        output_beam=params.beam_size,
         den_scale=params.den_scale,
         use_pruned_intersect=params.use_pruned_intersect,
     )
 
     mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
 
-    assert params.att_rate == 0
     if params.att_rate != 0.0:
-        # TODO: not working
         token_ids = graph_compiler.texts_to_ids(texts)
         with torch.set_grad_enabled(is_training):
             if hasattr(model, "module"):

@@ -304,16 +330,16 @@ def compute_loss(
                 encoder_memory,
                 memory_mask,
                 token_ids=token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
             )
         else:
             att_loss = model.decoder_forward(
                 encoder_memory,
                 memory_mask,
                 token_ids=token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
+                sos_id=params.sos_id,
+                eos_id=params.eos_id,
             )
         loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
     else:
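Note the hasattr(model, "module") check around the decoder_forward calls above: decoder_forward is a custom method, and DistributedDataParallel only proxies forward(), so custom methods must be reached through model.module once the model is wrapped. A toy, runnable sketch of that dispatch pattern (ToyModel is a hypothetical stand-in, not the actual Conformer):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Stand-in for the Conformer; DDP would only proxy forward()."""
    def forward(self, x):
        return x

    def decoder_forward(self, memory, memory_mask, token_ids, sos_id, eos_id):
        # Placeholder for the real attention-decoder loss computation.
        return torch.tensor(0.0)

model = ToyModel()
# After `model = DDP(model, device_ids=[rank])`, the custom method lives on
# model.module, hence the hasattr(model, "module") check in compute_loss.
m = model.module if hasattr(model, "module") else model
att_loss = m.decoder_forward(None, None, token_ids=[[1, 2]], sos_id=3, eos_id=3)
print(att_loss)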
@@ -587,7 +613,6 @@ def run(rank, world_size, args):
 
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
 
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")

@@ -617,12 +642,19 @@ def run(rank, world_size, args):
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
         vgg_frontend=False,
+        is_bpe=False,
         is_espnet_structure=params.is_espnet_structure,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    assert model.decoder_num_class == num_classes + 1
+
+    params.sos_id = num_classes
+    params.eos_id = num_classes
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
+    logging.info(params)
+
     model.to(device)
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

@@ -42,7 +42,7 @@ class Transformer(nn.Module):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        mmi_loss: bool = True,
+        is_bpe: bool = True,
         use_feat_batchnorm: bool = False,
     ) -> None:
         """

@@ -71,7 +71,9 @@ class Transformer(nn.Module):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          mmi_loss:
+          is_bpe:
+            True if the modeling unit is word pieces which has already included
+            SOS/EOS IDs.
           use_feat_batchnorm:
             True to use batchnorm for the input layer.
         """

@@ -123,7 +125,7 @@ class Transformer(nn.Module):
         )
 
         if num_decoder_layers > 0:
-            if mmi_loss:
+            if is_bpe is False:
                 self.decoder_num_class = (
                     self.num_classes + 1
                 )  # +1 for the sos/eos symbol
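With is_bpe=False (the non-word-piece MMI setup), the decoder vocabulary is extended by one symbol that serves as both SOS and EOS, and run() sets params.sos_id = params.eos_id = num_classes accordingly. A small illustrative sketch of how that extra id would frame a token sequence for the decoder; the numbers are hypothetical and the real recipe uses its own helpers for this:

num_classes = 5                      # e.g. token ids 0..4
decoder_num_class = num_classes + 1  # +1 for the shared sos/eos symbol
sos_id = eos_id = num_classes        # the extra id, here 5

token_ids = [3, 2, 4]                # one utterance, ids from texts_to_ids()
decoder_input = [sos_id] + token_ids   # [5, 3, 2, 4]
decoder_target = token_ids + [eos_id]  # [3, 2, 4, 5]

assert max(decoder_input) < decoder_num_class
assert max(decoder_target) < decoder_num_class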
@@ -11,6 +11,7 @@ def _compute_mmi_loss_exact_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """

@@ -79,7 +80,7 @@ def _compute_mmi_loss_exact_optimized(
     num_den_lats = k2.intersect_dense(
         num_den_reordered_graphs,
         dense_fsa_vec,
-        output_beam=8.0,
+        output_beam=output_beam,
         a_to_b_map=a_to_b_map,
     )
 
@@ -99,6 +100,7 @@ def _compute_mmi_loss_exact_non_optimized(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """

@@ -112,9 +114,12 @@ def _compute_mmi_loss_exact_non_optimized(
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
 
-    # TODO: pass output_beam as function argument
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=8.0)
-    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=8.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
+    den_lats = k2.intersect_dense(
+        den_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     num_tot_scores = num_lats.get_tot_scores(
         log_semiring=True, use_double_scores=True

@@ -134,6 +139,7 @@ def _compute_mmi_loss_pruned(
     dense_fsa_vec: k2.DenseFsaVec,
     texts: List[str],
     graph_compiler: MmiTrainingGraphCompiler,
+    output_beam: float,
     den_scale: float = 1.0,
 ) -> torch.Tensor:
     """

@@ -148,7 +154,9 @@ def _compute_mmi_loss_pruned(
     """
     num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
 
-    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
+    num_lats = k2.intersect_dense(
+        num_graphs, dense_fsa_vec, output_beam=output_beam
+    )
 
     # the values for search_beam/output_beam/min_active_states/max_active_states
     # are not tuned. You may want to tune them.

@@ -156,7 +164,7 @@ def _compute_mmi_loss_pruned(
         den_graphs,
         dense_fsa_vec,
         search_beam=20.0,
-        output_beam=8.0,
+        output_beam=output_beam,
         min_active_states=30,
         max_active_states=10000,
     )

@@ -185,13 +193,15 @@ class LFMMILoss(nn.Module):
     def __init__(
         self,
         graph_compiler: MmiTrainingGraphCompiler,
+        output_beam: float,
         use_pruned_intersect: bool = False,
         den_scale: float = 1.0,
     ):
         super().__init__()
         self.graph_compiler = graph_compiler
-        self.den_scale = den_scale
+        self.output_beam = output_beam
         self.use_pruned_intersect = use_pruned_intersect
+        self.den_scale = den_scale
 
     def forward(
         self,

@@ -218,5 +228,6 @@ class LFMMILoss(nn.Module):
             dense_fsa_vec=dense_fsa_vec,
             texts=texts,
             graph_compiler=self.graph_compiler,
+            output_beam=self.output_beam,
             den_scale=self.den_scale,
         )
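After these changes the output beam used when intersecting the numerator/denominator graphs with the dense FSA vector comes from params.beam_size (default 10) instead of being hard-coded to 8.0 or 10.0 inside each helper. The den_scale argument is untouched; as a reminder of what it does, here is a toy numeric sketch of the usual LF-MMI reduction over lattice total scores. The score values are invented and the exact reduction in mmi.py is not shown in this diff, so treat this as an assumption about the standard formulation:

import torch

# Hypothetical per-utterance total log-scores, as would be returned by
# get_tot_scores() on the numerator and denominator lattices.
num_tot_scores = torch.tensor([-320.0, -300.0])
den_tot_scores = torch.tensor([-310.0, -295.0])
den_scale = 1.0

tot_scores = num_tot_scores - den_scale * den_tot_scores
loss = -tot_scores.sum()
print(loss)  # -((-10.0) + (-5.0)) = 15.0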
@@ -155,6 +155,12 @@ def mmi_graph_compiler_test():
     )
     den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
 
+    texts = ["cat at cat", "ac at ca"]
+    token_ids = graph_compiler.texts_to_ids(texts)
+    #                      c  a  t  a  t  c  a  t     a  c  a  t  SPN
+    expected_token_ids = [[3, 2, 4, 2, 4, 3, 2, 4], [2, 3, 2, 4, 1]]
+    assert token_ids == expected_token_ids
+
 
 def test_main():
     generate_test_data()