minor updates

parent 975f6e03fd
commit ad7aeaccb0
@@ -31,6 +31,7 @@ class DecodeStream(object):
         params: AttributeDict,
         cut_id: str,
         initial_states: List[torch.Tensor],
+        decode_state: k2.DecodeStateInfo,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -50,6 +51,7 @@ class DecodeStream(object):
         self.LOG_EPS = math.log(1e-10)

         self.states = initial_states
+        self.decode_state = decode_state

         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
@@ -201,6 +201,7 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "feature_dim": 80,
+            "subsampling_factor": 4,
             "frame_shift_ms": 10,
             "search_beam": 20,
             "output_beam": 8,
@@ -216,7 +217,9 @@ def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
     H: Optional[k2.Fsa],
+    intersector: k2.OnlineDenseIntersecter,
     decode_streams: List[DecodeStream],
+    streams_to_pad: int = None,
 ) -> List[int]:
     """Decode one chunk frames of features for each decode_streams and
     return the indexes of finished streams in a List.
@@ -276,39 +279,49 @@ def decode_one_chunk(
     )
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

-    supervision_segments = torch.stack(
-        (
-            # supervisions["sequence_idx"],
-            torch.tensor([index for index, _ in enumerate(decode_streams)]),
-            torch.div(
-                0,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-            torch.div(
-                feature_lens,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-        ),
-        1,
-    ).to(torch.int32)
+    if streams_to_pad:
+        ctc_output = torch.cat(
+            [
+                ctc_output,
+                torch.zeros(
+                    (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
+                    device=device,
+                ),
+            ]
+        )

-    decoding_graph = H
-
-    lattice = get_lattice(
-        nnet_output=ctc_output,
-        decoding_graph=decoding_graph,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
+    supervision_segments = torch.tensor(
+        [[i, 0, 8] for i in range(params.num_decode_streams)],
+        dtype=torch.int32,
+    )
+
+    # decoding_graph = H
+
+    # lattice = get_lattice(
+    #     nnet_output=ctc_output,
+    #     decoding_graph=decoding_graph,
+    #     supervision_segments=supervision_segments,
+    #     search_beam=params.search_beam,
+    #     output_beam=params.output_beam,
+    #     min_active_states=params.min_active_states,
+    #     max_active_states=params.max_active_states,
+    #     subsampling_factor=params.subsampling_factor,
+    # )
+    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
+
+    current_decode_states = [
+        decode_stream.decode_state for decode_stream in decode_streams
+    ]
+    if streams_to_pad:
+        current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad
+    lattice, current_decode_states = intersector.decode(
+        dense_fsa_vec, current_decode_states
+    )

     best_path = one_best_decoding(
         lattice=lattice, use_double_scores=params.use_double_scores
     )

     # Note: `best_path.aux_labels` contains token IDs, not word IDs
     # since we are using H, not HLG here.
     #
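For orientation, here is a minimal sketch (not part of the commit) of the streaming-intersection pattern this hunk switches to, using only the k2 calls that appear in the diff; chunk_ctc_output and the fixed chunk length of 8 subsampled frames are illustrative assumptions:

# Sketch only (not the actual script): per-chunk decoding with
# k2.OnlineDenseIntersecter, following the calls used in this diff.
import k2
import torch

from icefall.decode import one_best_decoding

# Created once, outside the chunk loop (see the decode_dataset hunk below).
intersector = k2.OnlineDenseIntersecter(
    decoding_graph=H,  # CTC topology wrapped into an FsaVec (assumed to exist)
    num_streams=params.num_decode_streams,
    search_beam=params.search_beam,
    output_beam=params.output_beam,
    min_active_states=params.min_active_states,
    max_active_states=params.max_active_states,
)

# One DecodeStateInfo per stream, carried from chunk to chunk.
decode_states = [k2.DecodeStateInfo() for _ in range(params.num_decode_streams)]

# Per chunk: chunk_ctc_output is assumed to be (num_decode_streams, 8, C)
# log-probs; supervision columns are (sequence_idx, start_frame, num_frames),
# with the chunk length of 8 subsampled frames hard-coded as in the diff.
supervision_segments = torch.tensor(
    [[i, 0, 8] for i in range(params.num_decode_streams)],
    dtype=torch.int32,
)
dense_fsa_vec = k2.DenseFsaVec(chunk_ctc_output, supervision_segments)

# Decode this chunk; the returned states are fed back in on the next chunk.
lattice, decode_states = intersector.decode(dense_fsa_vec, decode_states)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

The per-stream k2.DecodeStateInfo objects returned by intersector.decode are what the next hunk writes back into decode_streams[i].decode_state between chunks.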
@@ -317,10 +330,15 @@ def decode_one_chunk(

     states = unstack_states(new_states)

+    num_streams = (
+        len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams)
+    )
+
     finished_streams = []
-    for i in range(len(decode_streams)):
+    for i in range(num_streams):
         decode_streams[i].hyp += token_ids[i]
         decode_streams[i].states = states[i]
+        decode_streams[i].decode_state = current_decode_states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -367,6 +385,15 @@ def decode_dataset(

     log_interval = 100

+    intersector = k2.OnlineDenseIntersecter(
+        decoding_graph=decoding_graph,
+        num_streams=params.num_decode_streams,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+    )
+
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
@@ -377,6 +404,7 @@ def decode_dataset(
             params=params,
             cut_id=cut.id,
             initial_states=initial_states,
+            decode_state=k2.DecodeStateInfo(),
             device=device,
         )

@@ -400,7 +428,11 @@ def decode_dataset(

         while len(decode_streams) >= params.num_decode_streams:
             finished_streams = decode_one_chunk(
-                params=params, model=model, decode_streams=decode_streams
+                params=params,
+                model=model,
+                H=decoding_graph,
+                intersector=intersector,
+                decode_streams=decode_streams,
             )
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
@@ -415,10 +447,16 @@ def decode_dataset(
         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")

+    num_remained_decode_streams = len(decode_streams)
     # decode final chunks of last sequences
-    while len(decode_streams):
+    while num_remained_decode_streams:
         finished_streams = decode_one_chunk(
-            params=params, model=model, H=decoding_graph, decode_streams=decode_streams
+            params=params,
+            model=model,
+            H=decoding_graph,
+            intersector=intersector,
+            decode_streams=decode_streams,
+            streams_to_pad=params.num_decode_streams - num_remained_decode_streams,
         )
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
@@ -429,6 +467,7 @@ def decode_dataset(
                 )
             )
             del decode_streams[i]
+            num_remained_decode_streams -= 1

     key = "ctc-decoding"
     return {key: decode_results}
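A small sketch of the tail handling that the decode_dataset hunks above introduce (illustrative only; ctc_output, decode_streams and params are assumed from the script): the intersecter always expects params.num_decode_streams streams, so the last partial batch is padded with zero log-probs and dummy k2.DecodeStateInfo objects, and only the live streams' results are kept.

# Sketch only: padding the final partial batch so the online intersecter
# always sees params.num_decode_streams streams.
import k2
import torch

num_remained_decode_streams = len(decode_streams)  # live streams left
streams_to_pad = params.num_decode_streams - num_remained_decode_streams

if streams_to_pad:
    # Pad the (num_remained, T, C) chunk of log-probs with zeros.
    ctc_output = torch.cat(
        [
            ctc_output,
            torch.zeros(
                (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
                device=ctc_output.device,
            ),
        ]
    )

# Pad the per-stream states with dummy DecodeStateInfo objects.
current_decode_states = [s.decode_state for s in decode_streams]
current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad

# After intersector.decode(...), only the first num_remained_decode_streams
# hypotheses and states belong to real streams and are written back.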
@@ -624,6 +663,7 @@ def main():
         modified=True,
         device=device,
     )
+    H = k2.Fsa.from_fsas([H])

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
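The single added line in main() wraps the CTC topology into a one-element FsaVec. A sketch of the graph setup, assuming max_token_id and device from the surrounding code:

# Sketch only: build a (modified) CTC topology and wrap it into an FsaVec,
# which appears to be the form passed to k2.OnlineDenseIntersecter as
# decoding_graph in decode_dataset above.
import k2

H = k2.ctc_topo(
    max_token=max_token_id,  # assumed to come from the tokenizer, as usual in icefall
    modified=True,
    device=device,
)
H = k2.Fsa.from_fsas([H])  # single-element FsaVec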