diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index 9d86b507a..aa031d930 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -45,7 +45,13 @@ from asr_datamodule import LibriSpeechAsrDataModule from ctc_decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet -from torch import Tensor, nn +from streaming_decode import ( + get_init_states, + stack_states, + streaming_forward, + unstack_states, +) +from torch import nn from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params @@ -55,18 +61,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) +from icefall.decode import get_lattice, one_best_decoding from icefall.utils import ( AttributeDict, get_texts, - make_pad_mask, setup_logger, store_transcripts, str2bool, @@ -198,234 +196,6 @@ def get_parser(): return parser -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. - """ - cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -493,7 +263,7 @@ def decode_one_chunk( supervision_segments = torch.stack( ( # supervisions["sequence_idx"], - list(map(lambda x: x.cut_id, decode_streams)), + torch.tensor([index for index, _ in enumerate(decode_streams)]), torch.div( 0, params.subsampling_factor,