minor updates

JinZr 2023-12-07 18:34:25 +08:00
parent 975f6e03fd
commit ad7aeaccb0
2 changed files with 74 additions and 32 deletions

View File

@@ -31,6 +31,7 @@ class DecodeStream(object):
        params: AttributeDict,
        cut_id: str,
        initial_states: List[torch.Tensor],
        decode_state: k2.DecodeStateInfo,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
@@ -50,6 +51,7 @@ class DecodeStream(object):
        self.LOG_EPS = math.log(1e-10)
        self.states = initial_states
        self.decode_state = decode_state
        # It contains a 2-D tensor representing the feature frames.
        self.features: torch.Tensor = None
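
Each stream now owns a k2.DecodeStateInfo, the object in which k2's OnlineDenseIntersecter keeps that stream's search state between chunks. A minimal sketch of constructing a stream under this change (params, model, cut, device, and the get_init_states helper are assumed to come from the surrounding icefall code):

    # Sketch: one DecodeStream per utterance, each with a fresh decode state.
    initial_states = get_init_states(model=model, batch_size=1, device=device)
    stream = DecodeStream(
        params=params,
        cut_id=cut.id,
        initial_states=initial_states,
        decode_state=k2.DecodeStateInfo(),  # empty; filled in by the intersecter
        device=device,
    )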

View File

@@ -201,6 +201,7 @@ def get_decoding_params() -> AttributeDict:
    params = AttributeDict(
        {
            "feature_dim": 80,
            "subsampling_factor": 4,
            "frame_shift_ms": 10,
            "search_beam": 20,
            "output_beam": 8,
@@ -216,7 +217,9 @@ def decode_one_chunk(
    params: AttributeDict,
    model: nn.Module,
    H: Optional[k2.Fsa],
    intersector: k2.OnlineDenseIntersecter,
    decode_streams: List[DecodeStream],
    streams_to_pad: Optional[int] = None,
) -> List[int]:
    """Decode one chunk of frames for each stream in decode_streams and
    return the indexes of the finished streams in a list.
@@ -276,39 +279,49 @@ def decode_one_chunk(
    )
    ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
    if streams_to_pad:
        ctc_output = torch.cat(
            [
                ctc_output,
                torch.zeros(
                    (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
                    device=device,
                ),
            ]
        )
    supervision_segments = torch.tensor(
        [[i, 0, 8] for i in range(params.num_decode_streams)],
        dtype=torch.int32,
    )
    # decoding_graph = H
    # lattice = get_lattice(
    #     nnet_output=ctc_output,
    #     decoding_graph=decoding_graph,
    #     supervision_segments=supervision_segments,
    #     search_beam=params.search_beam,
    #     output_beam=params.output_beam,
    #     min_active_states=params.min_active_states,
    #     max_active_states=params.max_active_states,
    #     subsampling_factor=params.subsampling_factor,
    # )
    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
    current_decode_states = [
        decode_stream.decode_state for decode_stream in decode_streams
    ]
    if streams_to_pad:
        current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad
    lattice, current_decode_states = intersector.decode(
        dense_fsa_vec, current_decode_states
    )
    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )
    # Note: `best_path.aux_labels` contains token IDs, not word IDs
    # since we are using H, not HLG here.
    #
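
The intersecter is built for a fixed num_streams, so when fewer live streams remain the batch is zero-padded up to that size and matching empty DecodeStateInfo objects are appended; results for the padded rows are simply discarded. A self-contained sketch of the padding step, with illustrative shapes in place of the real model output:

    import torch

    num_decode_streams = 4
    ctc_output = torch.randn(3, 8, 500)  # 3 live streams, 8 frames, 500 tokens
    streams_to_pad = num_decode_streams - ctc_output.size(0)
    if streams_to_pad:
        pad = torch.zeros(streams_to_pad, ctc_output.size(-2), ctc_output.size(-1))
        ctc_output = torch.cat([ctc_output, pad])
    assert ctc_output.shape == (4, 8, 500)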
@@ -317,10 +330,15 @@ def decode_one_chunk(
    states = unstack_states(new_states)

    num_streams = (
        len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams)
    )
    finished_streams = []
    for i in range(num_streams):
        decode_streams[i].hyp += token_ids[i]
        decode_streams[i].states = states[i]
        decode_streams[i].decode_state = current_decode_states[i]
        decode_streams[i].done_frames += encoder_out_lens[i]
        if decode_streams[i].done:
            finished_streams.append(i)
@@ -367,6 +385,15 @@ def decode_dataset(
    log_interval = 100

    intersector = k2.OnlineDenseIntersecter(
        decoding_graph=decoding_graph,
        num_streams=params.num_decode_streams,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    decode_results = []
    # Contains the decode streams currently running.
    decode_streams = []
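
The intersecter is created once per dataset rather than per chunk because it carries the streaming search state; each decode() call takes one DenseFsaVec chunk plus the current per-stream states and returns a lattice together with updated states. Schematically (the loop driver here is hypothetical; the constructor and decode() signatures are the ones this commit uses):

    states = [k2.DecodeStateInfo() for _ in range(params.num_decode_streams)]
    for dense_fsa_vec in chunks:  # one k2.DenseFsaVec per chunk, built as above
        lattice, states = intersector.decode(dense_fsa_vec, states)
        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)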
@@ -377,6 +404,7 @@ def decode_dataset(
            params=params,
            cut_id=cut.id,
            initial_states=initial_states,
            decode_state=k2.DecodeStateInfo(),
            device=device,
        )
@@ -400,7 +428,11 @@ def decode_dataset(
        while len(decode_streams) >= params.num_decode_streams:
            finished_streams = decode_one_chunk(
                params=params,
                model=model,
                H=decoding_graph,
                intersector=intersector,
                decode_streams=decode_streams,
            )
            for i in sorted(finished_streams, reverse=True):
                decode_results.append(
@@ -415,10 +447,16 @@ def decode_dataset(
        if num % log_interval == 0:
            logging.info(f"Cuts processed until now is {num}.")

    num_remained_decode_streams = len(decode_streams)
    # decode final chunks of last sequences
    while num_remained_decode_streams:
        finished_streams = decode_one_chunk(
            params=params,
            model=model,
            H=decoding_graph,
            intersector=intersector,
            decode_streams=decode_streams,
            streams_to_pad=params.num_decode_streams - num_remained_decode_streams,
        )
        for i in sorted(finished_streams, reverse=True):
            decode_results.append(
@@ -429,6 +467,7 @@ def decode_dataset(
                )
            )
            del decode_streams[i]
            num_remained_decode_streams -= 1

    key = "ctc-decoding"
    return {key: decode_results}
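
A worked example of the drain-loop arithmetic: the intersecter always expects num_decode_streams streams, so as streams finish, streams_to_pad grows to cover the gap (values illustrative):

    num_decode_streams = 4
    num_remained_decode_streams = 3  # three unfinished streams left
    streams_to_pad = num_decode_streams - num_remained_decode_streams  # -> 1
    # decode_one_chunk() pads the batch with one all-zero stream and one empty
    # k2.DecodeStateInfo() so the intersecter still sees exactly 4 streams.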
@@ -624,6 +663,7 @@ def main():
        modified=True,
        device=device,
    )
    H = k2.Fsa.from_fsas([H])

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")