icefall/icefall/mmi.py

from typing import List

import k2
import torch
from torch import nn

from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler


def _compute_mmi_loss_exact_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      dense_fsa_vec:
        It contains the neural network output.
      texts:
        The transcripts. Each element consists of space-separated words.
      graph_compiler:
        Used to build num_graphs and den_graphs.
      den_scale:
        The scale applied to the denominator tot_scores.
      beam_size:
        The output beam used in k2.intersect_dense.
    Returns:
      Return a scalar loss. It is the sum over utterances in a batch,
      without normalization.
    """
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)

    device = num_graphs.device
    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas
    assert den_graphs.shape[0] == 1

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = (
        torch.stack([num_graphs_indexes, den_graphs_indexes])
        .t()
        .reshape(-1)
        .to(device)
    )

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)
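    # For example, with num_fsas == 2 the reordered graphs are
    # [num_0, den, num_1, den] and a_to_b_map is [0, 0, 1, 1], so each
    # numerator/denominator pair is intersected with the same utterance
    # in dense_fsa_vec.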

    num_den_lats = k2.intersect_dense(
        num_den_reordered_graphs,
        dense_fsa_vec,
        output_beam=beam_size,
        a_to_b_map=a_to_b_map,
    )

    num_den_tot_scores = num_den_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True
    )
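    # Because of the interleaved ordering above, the scores alternate
    # num, den, num, den, ..., so even positions hold numerator scores
    # and odd positions hold denominator scores.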
    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    loss = -1 * tot_scores.sum()
    return loss


def _compute_mmi_loss_exact_non_optimized(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    It is more readable, though it invokes k2.intersect_dense twice.

    Note:
      It uses less memory than :func:`_compute_mmi_loss_exact_optimized`,
      at the cost of being slower.
    """
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    num_lats = k2.intersect_dense(
        num_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
    )
    den_lats = k2.intersect_dense(
        den_graphs, dense_fsa_vec, output_beam=beam_size, max_arcs=2147483600
    )

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    loss = -1 * tot_scores.sum()
    return loss


def _compute_mmi_loss_pruned(
    dense_fsa_vec: k2.DenseFsaVec,
    texts: List[str],
    graph_compiler: MmiTrainingGraphCompiler,
    den_scale: float = 1.0,
    beam_size: float = 8.0,
) -> torch.Tensor:
    """
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    `pruned` means it uses k2.intersect_dense_pruned.

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.
    """
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)

    # The numerator graphs are built from the transcripts and are small,
    # so exact intersection is affordable here.
    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size)

    # The values for search_beam/output_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(
        den_graphs,
        dense_fsa_vec,
        search_beam=20.0,
        output_beam=beam_size,
        min_active_states=30,
        max_active_states=10000,
    )

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    loss = -1 * tot_scores.sum()
    return loss


class LFMMILoss(nn.Module):
    """
    Computes the Lattice-Free Maximum Mutual Information (LF-MMI) loss.

    For each utterance the loss is

        -(num_tot_score - den_scale * den_tot_score)

    where num_tot_score is the total score of the lattice obtained by
    intersecting the network output with the numerator graph (built from
    the transcript), and den_tot_score is the corresponding score for the
    denominator graph. The returned loss is the sum over the batch.
    """

    def __init__(
        self,
        graph_compiler: MmiTrainingGraphCompiler,
        use_pruned_intersect: bool = False,
        den_scale: float = 1.0,
        beam_size: float = 8.0,
    ):
        super().__init__()
        self.graph_compiler = graph_compiler
        self.den_scale = den_scale
        self.use_pruned_intersect = use_pruned_intersect
        self.beam_size = beam_size

    def forward(
        self,
        dense_fsa_vec: k2.DenseFsaVec,
        texts: List[str],
    ) -> torch.Tensor:
        """
        Args:
          dense_fsa_vec:
            It contains the neural network output.
          texts:
            A list of strings. Each string contains space-separated words.
        Returns:
          Return a scalar loss. It is the sum over utterances in a batch,
          without normalization.
        """
        if self.use_pruned_intersect:
            func = _compute_mmi_loss_pruned
        else:
            func = _compute_mmi_loss_exact_non_optimized
            # func = _compute_mmi_loss_exact_optimized

        return func(
            dense_fsa_vec=dense_fsa_vec,
            texts=texts,
            graph_compiler=self.graph_compiler,
            den_scale=self.den_scale,
            beam_size=self.beam_size,
        )
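

# A minimal usage sketch, not part of the library API. It assumes a
# prepared lang directory for MmiTrainingGraphCompiler (the constructor
# arguments below are assumptions; check mmi_graph_compiler.py for the
# exact signature) and fakes the network output with random log-probs.
if __name__ == "__main__":
    device = torch.device("cpu")

    # Hypothetical lang dir; it must contain the lexicon files the
    # compiler expects.
    graph_compiler = MmiTrainingGraphCompiler("data/lang_phone", device=device)
    loss_fn = LFMMILoss(graph_compiler=graph_compiler, den_scale=1.0)

    N, T, C = 2, 50, 100  # C must match the token vocabulary size (assumption)
    log_probs = torch.randn(N, T, C, device=device).log_softmax(dim=-1)

    # Each supervision segment is [fsa_index, start_frame, num_frames].
    supervision_segments = torch.tensor(
        [[0, 0, T], [1, 0, T]], dtype=torch.int32
    )
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    # The words must be covered by the lexicon; OOVs are mapped to the
    # compiler's OOV token.
    loss = loss_fn(dense_fsa_vec, texts=["HELLO WORLD", "FOO BAR"])
    print(loss)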