mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Fix comments
This commit is contained in:
parent
dbfe8fbb1a
commit
089b8178f0
@ -26,23 +26,26 @@ from icefall.utils import get_texts
|
|||||||
|
|
||||||
|
|
||||||
def fast_beam_search(
|
def fast_beam_search(
|
||||||
decoding_graph: k2.Fsa,
|
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
) -> List[int]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model:
|
model:
|
||||||
An instance of `Transducer`.
|
An instance of `Transducer`.
|
||||||
encoder_out:
|
|
||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
beam:
|
beam:
|
||||||
Beam value, similar to the beam used in Kaldi..
|
Beam value, similar to the beam used in Kaldi..
|
||||||
max_states:
|
max_states:
|
||||||
@ -66,15 +69,17 @@ def fast_beam_search(
|
|||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
)
|
)
|
||||||
indivisual_streams = []
|
individual_streams = []
|
||||||
for i in range(B):
|
for i in range(B):
|
||||||
indivisual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
|
||||||
decoding_streams = k2.RnntDecodingStreams(indivisual_streams, config)
|
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
|
||||||
|
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
# shape is a RaggedShape of shape (B, context)
|
# shape is a RaggedShape of shape (B, context)
|
||||||
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||||
shape, contexts = decoding_streams.get_contexts()
|
shape, contexts = decoding_streams.get_contexts()
|
||||||
|
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||||
|
contexts = contexts.to(torch.int64)
|
||||||
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||||
decoder_out = model.decoder(contexts, need_pad=False)
|
decoder_out = model.decoder(contexts, need_pad=False)
|
||||||
# current_encoder_out is of shape
|
# current_encoder_out is of shape
|
||||||
@ -90,7 +95,7 @@ def fast_beam_search(
|
|||||||
logits = logits.squeeze(1).squeeze(1)
|
logits = logits.squeeze(1).squeeze(1)
|
||||||
log_probs = logits.log_softmax(dim=-1)
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
decoding_streams.advance(log_probs)
|
decoding_streams.advance(log_probs)
|
||||||
decoding_streams.terminate_and_flush_to_atreams()
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
|
|
||||||
best_path = one_best_decoding(lattice)
|
best_path = one_best_decoding(lattice)
|
||||||
|
@ -60,7 +60,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -135,16 +135,19 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""An interger indicating how many candidates we will keep for each
|
||||||
beam_search or modified_beam_search""",
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--beam",
|
"--beam",
|
||||||
type=float,
|
type=float,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
fast_beam_search""",
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -185,8 +188,8 @@ def decode_one_batch(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
decoding_graph: k2.Fsa,
|
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -293,7 +296,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user