diff --git a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py
index ed83d02a7..066f8afc3 100644
--- a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py
@@ -31,6 +31,7 @@ class DecodeStream(object):
         params: AttributeDict,
         cut_id: str,
         initial_states: List[torch.Tensor],
+        decode_state: k2.DecodeStateInfo,
         device: torch.device = torch.device("cpu"),
     ) -> None:
         """
@@ -50,6 +51,7 @@ class DecodeStream(object):
         self.LOG_EPS = math.log(1e-10)
 
         self.states = initial_states
+        self.decode_state = decode_state
 
         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py
index ab550290c..52a5162fb 100755
--- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py
@@ -201,6 +201,7 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "feature_dim": 80,
+            "subsampling_factor": 4,
             "frame_shift_ms": 10,
             "search_beam": 20,
             "output_beam": 8,
@@ -216,7 +217,9 @@ def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
     H: Optional[k2.Fsa],
+    intersector: k2.OnlineDenseIntersecter,
     decode_streams: List[DecodeStream],
+    streams_to_pad: int = None,
 ) -> List[int]:
     """Decode one chunk frames of features for each decode_streams and
     return the indexes of finished streams in a List.
@@ -276,39 +279,49 @@ def decode_one_chunk(
     )
     ctc_output = model.ctc_output(encoder_out)  # (N, T, C)
 
-    supervision_segments = torch.stack(
-        (
-            # supervisions["sequence_idx"],
-            torch.tensor([index for index, _ in enumerate(decode_streams)]),
-            torch.div(
-                0,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-            torch.div(
-                feature_lens,
-                params.subsampling_factor,
-                rounding_mode="floor",
-            ),
-        ),
-        1,
-    ).to(torch.int32)
+    if streams_to_pad:
+        ctc_output = torch.cat(
+            [
+                ctc_output,
+                torch.zeros(
+                    (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)),
+                    device=device,
+                ),
+            ]
+        )
 
-    decoding_graph = H
-
-    lattice = get_lattice(
-        nnet_output=ctc_output,
-        decoding_graph=decoding_graph,
-        supervision_segments=supervision_segments,
-        search_beam=params.search_beam,
-        output_beam=params.output_beam,
-        min_active_states=params.min_active_states,
-        max_active_states=params.max_active_states,
-        subsampling_factor=params.subsampling_factor,
+    supervision_segments = torch.tensor(
+        [[i, 0, 8] for i in range(params.num_decode_streams)],
+        dtype=torch.int32,
     )
+
+    # decoding_graph = H
+
+    # lattice = get_lattice(
+    #     nnet_output=ctc_output,
+    #     decoding_graph=decoding_graph,
+    #     supervision_segments=supervision_segments,
+    #     search_beam=params.search_beam,
+    #     output_beam=params.output_beam,
+    #     min_active_states=params.min_active_states,
+    #     max_active_states=params.max_active_states,
+    #     subsampling_factor=params.subsampling_factor,
+    # )
+    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
+
+    current_decode_states = [
+        decode_stream.decode_state for decode_stream in decode_streams
+    ]
+    if streams_to_pad:
+        current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad
+    lattice, current_decode_states = intersector.decode(
+        dense_fsa_vec, current_decode_states
+    )
+
     best_path = one_best_decoding(
         lattice=lattice, use_double_scores=params.use_double_scores
     )
+
     # Note: `best_path.aux_labels` contains token IDs, not word IDs
     # since we are using H, not HLG here.
     #
@@ -317,10 +330,15 @@ def decode_one_chunk(
 
     states = unstack_states(new_states)
 
+    num_streams = (
+        len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams)
+    )
+
     finished_streams = []
-    for i in range(len(decode_streams)):
+    for i in range(num_streams):
         decode_streams[i].hyp += token_ids[i]
         decode_streams[i].states = states[i]
+        decode_streams[i].decode_state = current_decode_states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -367,6 +385,15 @@ def decode_dataset(
 
     log_interval = 100
 
+    intersector = k2.OnlineDenseIntersecter(
+        decoding_graph=decoding_graph,
+        num_streams=params.num_decode_streams,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+    )
+
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
@@ -377,6 +404,7 @@ def decode_dataset(
             params=params,
             cut_id=cut.id,
             initial_states=initial_states,
+            decode_state=k2.DecodeStateInfo(),
             device=device,
         )
 
@@ -400,7 +428,11 @@ def decode_dataset(
 
         while len(decode_streams) >= params.num_decode_streams:
             finished_streams = decode_one_chunk(
-                params=params, model=model, decode_streams=decode_streams
+                params=params,
+                model=model,
+                H=decoding_graph,
+                intersector=intersector,
+                decode_streams=decode_streams,
             )
             for i in sorted(finished_streams, reverse=True):
                 decode_results.append(
@@ -415,10 +447,16 @@ def decode_dataset(
         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")
 
+    num_remained_decode_streams = len(decode_streams)
     # decode final chunks of last sequences
-    while len(decode_streams):
+    while num_remained_decode_streams:
         finished_streams = decode_one_chunk(
-            params=params, model=model, H=decoding_graph, decode_streams=decode_streams
+            params=params,
+            model=model,
+            H=decoding_graph,
+            intersector=intersector,
+            decode_streams=decode_streams,
+            streams_to_pad=params.num_decode_streams - num_remained_decode_streams,
         )
         for i in sorted(finished_streams, reverse=True):
             decode_results.append(
@@ -429,6 +467,7 @@ def decode_dataset(
                 )
             )
             del decode_streams[i]
+            num_remained_decode_streams -= 1
 
     key = "ctc-decoding"
     return {key: decode_results}
@@ -624,6 +663,7 @@ def main():
             modified=True,
             device=device,
         )
+        H = k2.Fsa.from_fsas([H])
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
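
Below is a minimal, self-contained sketch (not part of the patch) of the chunk-wise decoding flow the diff switches to: a k2.OnlineDenseIntersecter is created once for a fixed number of streams, each stream carries a k2.DecodeStateInfo across chunks, and every chunk of CTC log-probs is wrapped in a k2.DenseFsaVec and passed to intersector.decode(). The stream count, chunk length, vocabulary size, and active-state limits are made-up illustration values, and the random tensor stands in for model.ctc_output(encoder_out); only k2 calls that appear in the patch, plus k2.ctc_topo and k2.shortest_path, are used.

# Illustration only -- not part of the patch.  num_streams, chunk_frames and
# vocab_size are made up; the random tensor replaces model.ctc_output(...).
import k2
import torch

num_streams = 2    # parallel utterances decoded together
chunk_frames = 8   # encoder frames per chunk (the patch uses 8)
vocab_size = 500   # CTC output dimension, blank included

# CTC topology wrapped into an FsaVec, as done in main() above.
H = k2.ctc_topo(max_token=vocab_size - 1, modified=True)
decoding_graph = k2.Fsa.from_fsas([H])

intersector = k2.OnlineDenseIntersecter(
    decoding_graph=decoding_graph,
    num_streams=num_streams,
    search_beam=20,
    output_beam=8,
    min_active_states=30,
    max_active_states=10000,
)

# One DecodeStateInfo per stream carries the intersection state across chunks.
decode_states = [k2.DecodeStateInfo() for _ in range(num_streams)]

for _ in range(3):  # pretend three chunks arrive for every stream
    # Stand-in for model.ctc_output(encoder_out): (N, T, C) log-probs.
    ctc_output = torch.randn(num_streams, chunk_frames, vocab_size).log_softmax(-1)
    supervision_segments = torch.tensor(
        [[i, 0, chunk_frames] for i in range(num_streams)], dtype=torch.int32
    )
    dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)

    # decode() consumes this chunk and returns a lattice together with the
    # updated per-stream states to pass into the next call.
    lattice, decode_states = intersector.decode(dense_fsa_vec, decode_states)

# One-best path through the lattice returned by the last call.
best_path = k2.shortest_path(lattice, use_double_scores=True)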