Fix comments

pkufool 2022-03-18 10:45:47 +08:00
parent dbfe8fbb1a
commit 089b8178f0
2 changed files with 23 additions and 15 deletions

View File

@@ -26,23 +26,26 @@ from icefall.utils import get_texts

 def fast_beam_search(
-    decoding_graph: k2.Fsa,
     model: Transducer,
+    decoding_graph: k2.Fsa,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
     beam: float,
     max_states: int,
     max_contexts: int,
-) -> List[int]:
+) -> List[List[int]]:
     """It limits the maximum number of symbols per frame to 1.

     Args:
       model:
         An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+      decoding_graph:
+        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
       beam:
         Beam value, similar to the beam used in Kaldi.
       max_states:
@@ -66,15 +69,17 @@ def fast_beam_search(
         max_contexts=max_contexts,
         max_states=max_states,
     )
-    indivisual_streams = []
+    individual_streams = []
     for i in range(B):
-        indivisual_streams.append(k2.RnntDecodingStream(decoding_graph))
-    decoding_streams = k2.RnntDecodingStreams(indivisual_streams, config)
+        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
+    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)

     for t in range(T):
         # shape is a RaggedShape of shape (B, context)
         # contexts is a Tensor of shape (shape.NumElements(), context_size)
         shape, contexts = decoding_streams.get_contexts()
         # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
         contexts = contexts.to(torch.int64)
         # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
         decoder_out = model.decoder(contexts, need_pad=False)
         # current_encoder_out is of shape
@@ -90,7 +95,7 @@ def fast_beam_search(
         logits = logits.squeeze(1).squeeze(1)
         log_probs = logits.log_softmax(dim=-1)
         decoding_streams.advance(log_probs)
-    decoding_streams.terminate_and_flush_to_atreams()
+    decoding_streams.terminate_and_flush_to_streams()
     lattice = decoding_streams.format_output(encoder_out_lens.tolist())
     best_path = one_best_decoding(lattice)
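
For context, a minimal usage sketch of the reordered signature and the new return type. This is not part of the commit; the tensors and the `k2.trivial_graph` choice of decoding graph are illustrative assumptions.

    # Hypothetical usage sketch (not from this commit).
    # Assumes `model` is a Transducer and `encoder_out` is a padded
    # (N, T, C) tensor with valid lengths `encoder_out_lens` of shape (N,).
    decoding_graph = k2.trivial_graph(
        model.decoder.vocab_size - 1, device=encoder_out.device
    )
    hyps = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=4.0,
        max_states=32,
        max_contexts=8,
    )
    # Per the new annotation, `hyps` is a List[List[int]]:
    # one list of token ids per utterance in the batch, not a single List[int].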

View File

@@ -60,7 +60,7 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import k2
 import sentencepiece as spm
@@ -135,16 +135,19 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""Used only when --decoding-method is
-        beam_search or modified_beam_search""",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
     )

     parser.add_argument(
         "--beam",
         type=float,
         default=4,
-        help="""Used only when --decoding-method is
-        fast_beam_search""",
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search.""",
     )

     parser.add_argument(
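
As a quick illustration of the pruning rule described in the new `--beam` help text (the numbers here are made up, not from the commit):

    beam = 4.0                  # value passed via --beam
    max_score = -12.3           # hypothetical best partial score on some frame
    cutoff = max_score - beam   # -16.3; hypotheses scoring below this are pruned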
@@ -185,8 +188,8 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    decoding_graph: k2.Fsa,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -293,7 +296,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    decoding_graph: k2.Fsa,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.