mirror of https://github.com/k2-fsa/icefall.git

commit ad7aeaccb0 (parent 975f6e03fd): minor updates
@@ -31,6 +31,7 @@ class DecodeStream(object):
         params: AttributeDict,
         cut_id: str,
         initial_states: List[torch.Tensor],
+        decode_state: k2.DecodeStateInfo,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -50,6 +51,7 @@ class DecodeStream(object):
         self.LOG_EPS = math.log(1e-10)

         self.states = initial_states
+        self.decode_state = decode_state

         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
@@ -201,6 +201,7 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "feature_dim": 80,
+            "subsampling_factor": 4,
             "frame_shift_ms": 10,
             "search_beam": 20,
             "output_beam": 8,
@@ -216,7 +217,9 @@ def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
     H: Optional[k2.Fsa],
+    intersector: k2.OnlineDenseIntersecter,
     decode_streams: List[DecodeStream],
+    streams_to_pad: int = None,
 ) -> List[int]:
     """Decode one chunk frames of features for each decode_streams and
     return the indexes of finished streams in a List.
@@ -276,39 +279,49 @@ def decode_one_chunk(
     )
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)

-    supervision_segments = torch.stack(
-        (
-            # supervisions["sequence_idx"],
-            torch.tensor([index for index, _ in enumerate(decode_streams)]),
-            torch.div(
-                0,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-            torch.div(
-                feature_lens,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-        ),
-        1,
-    ).to(torch.int32)
-
-    decoding_graph = H
-
-    lattice = get_lattice(
-        nnet_output=ctc_output,
-        decoding_graph=decoding_graph,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
-    )
+    if streams_to_pad:
+        ctc_output = torch.cat(
+            [
+                ctc_output,
+                torch.zeros(
+                    (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
+                    device=device,
+                ),
+            ]
+        )
+
+    supervision_segments = torch.tensor(
+        [[i, 0, 8] for i in range(params.num_decode_streams)],
+        dtype=torch.int32,
+    )
+
+    # decoding_graph = H
+
+    # lattice = get_lattice(
+    #     nnet_output=ctc_output,
+    #     decoding_graph=decoding_graph,
+    #     supervision_segments=supervision_segments,
+    #     search_beam=params.search_beam,
+    #     output_beam=params.output_beam,
+    #     min_active_states=params.min_active_states,
+    #     max_active_states=params.max_active_states,
+    #     subsampling_factor=params.subsampling_factor,
+    # )
+
+    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
+
+    current_decode_states = [
+        decode_stream.decode_state for decode_stream in decode_streams
+    ]
+    if streams_to_pad:
+        current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad
+    lattice, current_decode_states = intersector.decode(
+        dense_fsa_vec, current_decode_states
+    )

     best_path = one_best_decoding(
         lattice=lattice, use_double_scores=params.use_double_scores
     )

     # Note: `best_path.aux_labels` contains token IDs, not word IDs
     # since we are using H, not HLG here.
     #
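Note on the hunk above, which is the core of the change: instead of running the offline get_lattice over a batch stacked from feature_lens, each chunk of CTC output is wrapped in a k2.DenseFsaVec and pushed through a persistent k2.OnlineDenseIntersecter, whose per-stream DecodeStateInfo objects carry the search state from one chunk to the next. The sketch below is a minimal restatement of that per-chunk flow, not the recipe's exact code: intersect_one_chunk is an illustrative name, the chunk length is taken from ctc_output rather than the hard-coded 8 frames used above, and the one_best_decoding import path is assumed from icefall.

from typing import List, Tuple

import k2
import torch

from icefall.decode import one_best_decoding  # assumed import path


def intersect_one_chunk(
    intersector: k2.OnlineDenseIntersecter,
    ctc_output: torch.Tensor,  # (num_active_streams, chunk_frames, vocab_size)
    decode_states: List[k2.DecodeStateInfo],
    num_decode_streams: int,  # fixed batch width of the intersecter
    device: torch.device,
) -> Tuple[k2.Fsa, List[k2.DecodeStateInfo]]:
    streams_to_pad = num_decode_streams - ctc_output.size(0)
    if streams_to_pad > 0:
        # Pad missing streams with zero log-probs and fresh states so the
        # batch always has exactly num_decode_streams entries.
        ctc_output = torch.cat(
            [
                ctc_output,
                torch.zeros(
                    (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
                    device=device,
                ),
            ]
        )
        decode_states = decode_states + [k2.DecodeStateInfo()] * streams_to_pad

    # One supervision row per stream: (stream_index, start_frame, num_frames).
    supervision_segments = torch.tensor(
        [[i, 0, ctc_output.size(1)] for i in range(num_decode_streams)],
        dtype=torch.int32,
    )
    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)

    # Advance the frame-synchronous intersection by one chunk; the returned
    # states are fed back in on the next chunk.
    lattice, decode_states = intersector.decode(dense_fsa_vec, decode_states)
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    return best_path, decode_states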
@@ -317,10 +330,15 @@ def decode_one_chunk(

     states = unstack_states(new_states)

+    num_streams = (
+        len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams)
+    )
+
     finished_streams = []
-    for i in range(len(decode_streams)):
+    for i in range(num_streams):
         decode_streams[i].hyp += token_ids[i]
         decode_streams[i].states = states[i]
+        decode_streams[i].decode_state = current_decode_states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -367,6 +385,15 @@ def decode_dataset(

     log_interval = 100

+    intersector = k2.OnlineDenseIntersecter(
+        decoding_graph=decoding_graph,
+        num_streams=params.num_decode_streams,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+    )
+
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
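For context, the intersecter added here is constructed once per decoding run, before the dataset loop, and is then shared by every decode_one_chunk call; its num_streams fixes the batch width that the padding logic above has to honour. A small sketch of that setup under the same assumptions as the diff (the helper name make_intersecter is illustrative only):

import k2


def make_intersecter(params, decoding_graph: k2.Fsa) -> k2.OnlineDenseIntersecter:
    # Reuse the same beam and active-state settings as the offline CTC decoder.
    return k2.OnlineDenseIntersecter(
        decoding_graph=decoding_graph,
        num_streams=params.num_decode_streams,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

Each new DecodeStream is then seeded with a fresh k2.DecodeStateInfo() (see the hunk below), which the intersecter replaces with an updated state after every chunk.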
@@ -377,6 +404,7 @@ def decode_dataset(
             params=params,
             cut_id=cut.id,
             initial_states=initial_states,
+            decode_state=k2.DecodeStateInfo(),
             device=device,
         )

@@ -400,7 +428,11 @@ def decode_dataset(

         while len(decode_streams) >= params.num_decode_streams:
             finished_streams = decode_one_chunk(
-                params=params, model=model, decode_streams=decode_streams
+                params=params,
+                model=model,
+                H=decoding_graph,
+                intersector=intersector,
+                decode_streams=decode_streams,
             )
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
@@ -415,10 +447,16 @@ def decode_dataset(
         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")

+    num_remained_decode_streams = len(decode_streams)
     # decode final chunks of last sequences
-    while len(decode_streams):
+    while num_remained_decode_streams:
         finished_streams = decode_one_chunk(
-            params=params, model=model, H=decoding_graph, decode_streams=decode_streams
+            params=params,
+            model=model,
+            H=decoding_graph,
+            intersector=intersector,
+            decode_streams=decode_streams,
+            streams_to_pad=params.num_decode_streams - num_remained_decode_streams,
         )
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
@@ -429,6 +467,7 @@ def decode_dataset(
                 )
             )
             del decode_streams[i]
+            num_remained_decode_streams -= 1

     key = "ctc-decoding"
     return {key: decode_results}
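The last two decode_dataset hunks handle the tail of the dataset: once fewer than params.num_decode_streams utterances remain, the batch is padded up to the intersecter's fixed width via streams_to_pad, and the counter shrinks as streams finish. A compressed sketch of that drain loop, assuming the decode_one_chunk patched above; the per-stream result collection is elided because this commit does not change it:

def drain_remaining_streams(params, model, H, intersector, decode_streams) -> None:
    # Decode the final chunks of the streams that are still running.
    num_remained_decode_streams = len(decode_streams)
    while num_remained_decode_streams:
        finished_streams = decode_one_chunk(
            params=params,
            model=model,
            H=H,
            intersector=intersector,
            decode_streams=decode_streams,
            # Pad the batch up to the intersecter's fixed width.
            streams_to_pad=params.num_decode_streams - num_remained_decode_streams,
        )
        # Delete finished streams from the highest index down so the remaining
        # indexes stay valid; each deletion means one more padded slot on the
        # next chunk.
        for i in sorted(finished_streams, reverse=True):
            del decode_streams[i]
            num_remained_decode_streams -= 1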
@@ -624,6 +663,7 @@ def main():
         modified=True,
         device=device,
     )
+    H = k2.Fsa.from_fsas([H])

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
|
Loading…
x
Reference in New Issue
Block a user