Merge 7c5249fb88dc39ff56b1a6c525ca191fe8a7154a into bc284e88e6459423b57fdef80ce4a8aed6122dcc

Authored by Fangjun Kuang on 2022-05-11 13:56:17 +08:00, committed by GitHub
commit 7fc2dd1e0b
4 changed files with 304 additions and 23 deletions

View File

@ -22,6 +22,162 @@ import k2
import torch
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
use_max: bool = False,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
use_max:
True to use the max operation to select the hypothesis with the largest
log_prob when there are duplicate hypotheses; False to use log-add.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
tmp_len = torch.tensor([1])
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
# current_encoder_out is of shape
# (shape.NumElements(), 1, encoder_out_dim)
# fmt: off
current_encoder_out = torch.index_select(encoder_out[:, t:t + 1, :], 0,
shape.row_ids(1))
# fmt: on
logits = model.joiner(
current_encoder_out,
decoder_out,
tmp_len.expand(decoder_out.size(0)),
tmp_len.expand(decoder_out.size(0)),
)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
# Note: lattice is actually an FSA if the graph is a k2.trivial_graph()
if use_max:
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
else:
num_paths = 200
use_double_scores = True
nbest_scale = 0.8
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# The following code is modified from nbest.intersect()
word_fsa = k2.invert(nbest.fsa)
if hasattr(lattice, "aux_labels"):
# delete token IDs as they are not needed
del word_fsa.aux_labels
word_fsa.scores.zero_()
# remove the state axis: [fsa][state][arc] -> [fsa][arc]
word_fsa_shape = word_fsa.arcs.shape().remove_axis(1)
num_arcs = (
word_fsa_shape.row_splits(1)[1:] - word_fsa_shape.row_splits(1)[:-1]
)
num_tokens_per_path = num_arcs - 1 # minus one due to the final arc
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
word_fsa
)
path_to_utt_map = nbest.shape.row_ids(1)
if hasattr(lattice, "aux_labels"):
# lattice has token IDs as labels and word IDs as aux_labels.
# inv_lattice has word IDs as labels and token IDs as aux_labels
inv_lattice = k2.invert(lattice)
inv_lattice = k2.arc_sort(inv_lattice)
else:
inv_lattice = k2.arc_sort(lattice)
if inv_lattice.shape[0] == 1:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=torch.zeros_like(path_to_utt_map),
sorted_match_a=True,
)
else:
path_lattice = k2.intersect_device(
inv_lattice,
word_fsa_with_epsilon_loops,
b_to_a_map=path_to_utt_map,
sorted_match_a=True,
)
# path_lattice has word IDs as labels and token IDs as aux_labels
path_lattice = k2.top_sort(k2.connect(path_lattice))
tot_scores = path_lattice.get_tot_scores(
use_double_scores=use_double_scores, log_semiring=True
)
tot_scores = tot_scores / num_tokens_per_path
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
best_hyp_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
hyps = get_texts(best_path)
return hyps
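A minimal usage sketch of the new function, assuming a trained `Transducer` model and padded encoder output already computed; the concrete values below are illustrative, not prescribed by this commit:
decoding_graph = k2.trivial_graph(
    model.decoder.vocab_size - 1, device=encoder_out.device
)
hyps = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,            # (N, T, C)
    encoder_out_lens=encoder_out_lens,  # (N,), lengths before padding
    beam=4.0,
    max_states=8,
    max_contexts=4,
    use_max=False,  # merge duplicate hypotheses with log-add
)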
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
@ -77,7 +233,7 @@ def greedy_search(
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on
logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
@ -204,7 +360,7 @@ class HypothesisList(object):
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
@ -213,13 +369,20 @@ class HypothesisList(object):
Args:
hyp:
The hypothesis to be added.
use_max:
True to select the hypothesis with the larger log_prob in case there
already exists a hypothesis whose `ys` equals `hyp.ys`.
False to use log_add.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
if use_max:
old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
else:
torch.logaddexp(
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
)
else:
self._data[key] = hyp
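A tiny illustration of the two merge rules used by `add()` when a duplicate hypothesis arrives; the log-prob values are made up for the example:
import torch

a = torch.tensor(-1.2)  # log_prob of the hypothesis already stored
b = torch.tensor(-0.7)  # log_prob of the duplicate being added

merged_log_add = torch.logaddexp(a, b)  # use_max=False: combine probabilities
merged_max = torch.maximum(a, b)        # use_max=True: keep the better score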
@ -416,6 +579,7 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
@ -426,6 +590,10 @@ def modified_beam_search(
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
@ -444,7 +612,8 @@ def modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@ -522,7 +691,7 @@ def modified_beam_search(
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
B[i].add(new_hyp, use_max=use_max)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
@ -534,6 +703,7 @@ def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
@ -548,6 +718,10 @@ def _deprecated_modified_beam_search(
A tensor of shape (N, T, C) from the encoder. Supports only N==1 for now.
beam:
Beam size.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return the decoded result.
"""
@ -568,7 +742,8 @@ def _deprecated_modified_beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
encoder_out_len = torch.tensor([1])
@ -576,7 +751,7 @@ def _deprecated_modified_beam_search(
for t in range(T):
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# current_encoder_out is of shape (1, 1, encoder_out_dim)
# fmt: on
A = list(B)
@ -629,7 +804,7 @@ def _deprecated_modified_beam_search(
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B.add(new_hyp)
B.add(new_hyp, use_max=use_max)
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
@ -641,6 +816,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
use_max: bool = False,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -654,6 +830,10 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Supports only N==1 for now.
beam:
Beam size.
use_max:
If True, use the max operation to select the hypothesis with the
larger log_prob when two hypotheses have the same token sequence.
If False, use log-add.
Returns:
Return the decoded result.
"""
@ -680,7 +860,8 @@ def beam_search(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
),
use_max=use_max,
)
max_sym_per_utt = 20000
@ -694,7 +875,7 @@ def beam_search(
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
current_encoder_out = encoder_out[:, t:t + 1, :]
# fmt: on
A = B
B = HypothesisList()
@ -726,7 +907,10 @@ def beam_search(
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
B.add(
Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
use_max=use_max,
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
@ -738,7 +922,10 @@ def beam_search(
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + values[idx]
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
A.add(
Hypothesis(ys=new_ys, log_prob=new_log_prob),
use_max=use_max,
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
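A minimal call sketch for the new `use_max` switch threaded through this file, assuming a trained `Transducer` and batched encoder output (all names and values are illustrative):
# use_max=True keeps the larger log_prob when two hypotheses share the same
# token sequence; use_max=False (the default) merges them with log-add.
hyps = modified_beam_search(
    model=model,
    encoder_out=encoder_out,  # (N, T, C)
    beam=4,
    use_max=True,
)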

View File

@ -22,7 +22,8 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--max-sym-per-frame 1 \
--decoding-method greedy_search
(2) beam search
@ -30,7 +31,7 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--decoding-method beam_search \
--beam-size 4
@ -39,9 +40,20 @@ Usage:
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 100 \
--max-duration 1000 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_stateless/decode.py \
--epoch 14 \
--avg 7 \
--exp-dir ./transducer_stateless/exp \
--max-duration 1000 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -49,14 +61,16 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -68,6 +82,7 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -115,6 +130,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -123,7 +139,36 @@ def get_parser():
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
beam_search or modified_beam_search.
It specifies the number of active hypotheses to keep at each
time step.
""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -141,6 +186,17 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--use-max",
type=str2bool,
default=False,
help="""If True, use max-op to select the hypothesis that have the
max log_prob in case of duplicate hypotheses.
If False, use log_add.
Used only for beam_search, modified_beam_search, and fast_beam_search
""",
)
return parser
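The `--beam` help above states the cutoff rule `cutoff = max-score - beam`. A minimal sketch of that pruning idea, purely illustrative and not how k2 implements it internally:
import torch

def prune_by_beam(scores: torch.Tensor, beam: float) -> torch.Tensor:
    """Keep only entries whose score is within `beam` of the best score."""
    cutoff = scores.max() - beam
    return scores >= cutoff  # boolean mask of surviving entries

# With beam=4.0, anything more than 4.0 below the best score is dropped:
mask = prune_by_beam(torch.tensor([0.0, -1.5, -3.9, -4.1, -7.0]), beam=4.0)
# mask -> tensor([ True,  True,  True, False, False])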
@ -149,6 +205,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -190,7 +247,18 @@ def decode_one_batch(
)
hyp_list: List[List[int]] = []
if (
if params.decoding_method == "fast_beam_search":
hyp_list = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
use_max=params.use_max,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
@ -203,6 +271,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
use_max=params.use_max,
)
else:
batch_size = encoder_out.size(0)
@ -221,6 +290,7 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
use_max=params.use_max,
)
else:
raise ValueError(
@ -232,6 +302,14 @@ def decode_one_batch(
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
@ -241,6 +319,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -281,6 +360,7 @@ def decode_dataset(
model=model,
sp=sp,
batch=batch,
decoding_graph=decoding_graph,
)
for name, hyps in hyps_dict.items():
@ -358,15 +438,22 @@ def main():
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"greedy_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
if "fast_beam_search" == params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-use-max-{params.use_max}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-use-max-{params.use_max}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -408,6 +495,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
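The `fast_beam_search` docstring notes the decoding graph may also be an HLG rather than the trivial graph; a hedged sketch of plugging one in (the file path and the `arc_sort` step are assumptions, not part of this commit):
# Assumed path to a precompiled HLG; adjust to your own lang directory.
hlg = k2.Fsa.from_dict(
    torch.load("data/lang_bpe_500/HLG.pt", map_location=device)
)
decoding_graph = k2.arc_sort(hlg)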
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +520,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -58,6 +58,7 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.blank_id = blank_id
self.vocab_size = vocab_size
assert context_size >= 1, context_size
self.context_size = context_size

View File

@ -218,7 +218,7 @@ class Nbest(object):
# word_seq is a k2.RaggedTensor sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
# The last entry in each sublist is -1.
# It axes is [utt][path][word_id]
# Its axes are [utt][path][word_id]
if isinstance(lattice.aux_labels, torch.Tensor):
word_seq = k2.ragged.index(lattice.aux_labels, path)
else: