mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-18 05:32:20 +00:00)

Update streaming_ctc_decode.py

parent 20f96ac6a8
commit 7fb3b13066
@@ -45,7 +45,13 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from ctc_decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
-from torch import Tensor, nn
+from streaming_decode import (
+    get_init_states,
+    stack_states,
+    streaming_forward,
+    unstack_states,
+)
+from torch import nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
 
@@ -55,18 +61,10 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import (
     AttributeDict,
     get_texts,
-    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
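The decode import is narrowed to the two helpers the 1best CTC path actually calls. For context, a minimal sketch of how they combine in an icefall CTC decode step; the surrounding names (params, H, nnet_output) are assumptions borrowed from the usual icefall decode scripts, not part of this commit:

# Hypothetical sketch: CTC 1best decoding with the two imported helpers.
# `nnet_output` is (N, T, C) log-probs from the encoder + CTC head; `H` is
# the decoding graph; `supervision_segments` is the int32 tensor built in
# decode_one_chunk below.
lattice = get_lattice(
    nnet_output=nnet_output,
    decoding_graph=H,
    supervision_segments=supervision_segments,
    search_beam=params.search_beam,
    output_beam=params.output_beam,
    min_active_states=params.min_active_states,
    max_active_states=params.max_active_states,
    subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
token_ids = get_texts(best_path)  # list of token-id lists, one per stream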
@@ -198,234 +196,6 @@ def get_parser():
     return parser
 
 
-def get_init_states(
-    model: nn.Module,
-    batch_size: int = 1,
-    device: torch.device = torch.device("cpu"),
-) -> List[torch.Tensor]:
-    """
-    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
-    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
-    states[-2] is the cached left padding for ConvNeXt module,
-    of shape (batch_size, num_channels, left_pad, num_freqs)
-    states[-1] is processed_lens of shape (batch,), which records the number
-    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
-    """
-    states = model.encoder.get_init_states(batch_size, device)
-
-    embed_states = model.encoder_embed.get_init_states(batch_size, device)
-    states.append(embed_states)
-
-    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
-    states.append(processed_lens)
-
-    return states
-
-
-def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
-    """Stack list of zipformer states that correspond to separate utterances
-    into a single emformer state, so that it can be used as an input for
-    zipformer when those utterances are formed into a batch.
-
-    Args:
-      state_list:
-        Each element in state_list corresponding to the internal state
-        of the zipformer model for a single utterance. For element-n,
-        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
-        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
-        cached_val2, cached_conv1, cached_conv2).
-        state_list[n][-2] is the cached left padding for ConvNeXt module,
-        of shape (batch_size, num_channels, left_pad, num_freqs)
-        state_list[n][-1] is processed_lens of shape (batch,), which records the number
-        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
-
-    Note:
-      It is the inverse of :func:`unstack_states`.
-    """
-    batch_size = len(state_list)
-    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
-    tot_num_layers = (len(state_list[0]) - 2) // 6
-
-    batch_states = []
-    for layer in range(tot_num_layers):
-        layer_offset = layer * 6
-        # cached_key: (left_context_len, batch_size, key_dim)
-        cached_key = torch.cat(
-            [state_list[i][layer_offset] for i in range(batch_size)], dim=1
-        )
-        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
-        cached_nonlin_attn = torch.cat(
-            [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
-        )
-        # cached_val1: (left_context_len, batch_size, value_dim)
-        cached_val1 = torch.cat(
-            [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
-        )
-        # cached_val2: (left_context_len, batch_size, value_dim)
-        cached_val2 = torch.cat(
-            [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
-        )
-        # cached_conv1: (#batch, channels, left_pad)
-        cached_conv1 = torch.cat(
-            [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
-        )
-        # cached_conv2: (#batch, channels, left_pad)
-        cached_conv2 = torch.cat(
-            [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
-        )
-        batch_states += [
-            cached_key,
-            cached_nonlin_attn,
-            cached_val1,
-            cached_val2,
-            cached_conv1,
-            cached_conv2,
-        ]
-
-    cached_embed_left_pad = torch.cat(
-        [state_list[i][-2] for i in range(batch_size)], dim=0
-    )
-    batch_states.append(cached_embed_left_pad)
-
-    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
-    batch_states.append(processed_lens)
-
-    return batch_states
-
-
-def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
-    """Unstack the zipformer state corresponding to a batch of utterances
-    into a list of states, where the i-th entry is the state from the i-th
-    utterance in the batch.
-
-    Note:
-      It is the inverse of :func:`stack_states`.
-
-    Args:
-      batch_states: A list of cached tensors of all encoder layers. For layer-i,
-        states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
-        cached_conv1, cached_conv2).
-        state_list[-2] is the cached left padding for ConvNeXt module,
-        of shape (batch_size, num_channels, left_pad, num_freqs)
-        states[-1] is processed_lens of shape (batch,), which records the number
-        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
-
-    Returns:
-      state_list: A list of list. Each element in state_list corresponding to the internal state
-        of the zipformer model for a single utterance.
-    """
-    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
-    tot_num_layers = (len(batch_states) - 2) // 6
-
-    processed_lens = batch_states[-1]
-    batch_size = processed_lens.shape[0]
-
-    state_list = [[] for _ in range(batch_size)]
-
-    for layer in range(tot_num_layers):
-        layer_offset = layer * 6
-        # cached_key: (left_context_len, batch_size, key_dim)
-        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
-        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
-        cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
-            chunks=batch_size, dim=1
-        )
-        # cached_val1: (left_context_len, batch_size, value_dim)
-        cached_val1_list = batch_states[layer_offset + 2].chunk(
-            chunks=batch_size, dim=1
-        )
-        # cached_val2: (left_context_len, batch_size, value_dim)
-        cached_val2_list = batch_states[layer_offset + 3].chunk(
-            chunks=batch_size, dim=1
-        )
-        # cached_conv1: (#batch, channels, left_pad)
-        cached_conv1_list = batch_states[layer_offset + 4].chunk(
-            chunks=batch_size, dim=0
-        )
-        # cached_conv2: (#batch, channels, left_pad)
-        cached_conv2_list = batch_states[layer_offset + 5].chunk(
-            chunks=batch_size, dim=0
-        )
-        for i in range(batch_size):
-            state_list[i] += [
-                cached_key_list[i],
-                cached_nonlin_attn_list[i],
-                cached_val1_list[i],
-                cached_val2_list[i],
-                cached_conv1_list[i],
-                cached_conv2_list[i],
-            ]
-
-    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
-    for i in range(batch_size):
-        state_list[i].append(cached_embed_left_pad_list[i])
-
-    processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
-    for i in range(batch_size):
-        state_list[i].append(processed_lens_list[i])
-
-    return state_list
-
-
-def streaming_forward(
-    features: Tensor,
-    feature_lens: Tensor,
-    model: nn.Module,
-    states: List[Tensor],
-    chunk_size: int,
-    left_context_len: int,
-) -> Tuple[Tensor, Tensor, List[Tensor]]:
-    """
-    Returns encoder outputs, output lengths, and updated states.
-    """
-    cached_embed_left_pad = states[-2]
-    (
-        x,
-        x_lens,
-        new_cached_embed_left_pad,
-    ) = model.encoder_embed.streaming_forward(
-        x=features,
-        x_lens=feature_lens,
-        cached_left_pad=cached_embed_left_pad,
-    )
-    assert x.size(1) == chunk_size, (x.size(1), chunk_size)
-
-    src_key_padding_mask = make_pad_mask(x_lens)
-
-    # processed_mask is used to mask out initial states
-    processed_mask = torch.arange(left_context_len, device=x.device).expand(
-        x.size(0), left_context_len
-    )
-    processed_lens = states[-1]  # (batch,)
-    # (batch, left_context_size)
-    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
-    # Update processed lengths
-    new_processed_lens = processed_lens + x_lens
-
-    # (batch, left_context_size + chunk_size)
-    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
-
-    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-    encoder_states = states[:-2]
-    (
-        encoder_out,
-        encoder_out_lens,
-        new_encoder_states,
-    ) = model.encoder.streaming_forward(
-        x=x,
-        x_lens=x_lens,
-        states=encoder_states,
-        src_key_padding_mask=src_key_padding_mask,
-    )
-    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-
-    new_states = new_encoder_states + [
-        new_cached_embed_left_pad,
-        new_processed_lens,
-    ]
-    return encoder_out, encoder_out_lens, new_states
-
-
 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
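These four helpers now live in streaming_decode.py and are imported at the top of the file; their contract is unchanged. A small usage sketch of the round trip the docstrings describe; model, features, feature_lens, and all numeric values here are illustrative assumptions, not values from this commit:

# Hypothetical usage sketch: batch per-stream states, run one chunk through
# the encoder, then split the states back per stream.
device = torch.device("cpu")
states = [get_init_states(model=model, device=device) for _ in range(4)]

batch_states = stack_states(states)  # one state list for the whole batch
encoder_out, encoder_out_lens, batch_states = streaming_forward(
    features=features,          # (4, chunk_frames, num_features), illustrative
    feature_lens=feature_lens,  # (4,), illustrative
    model=model,
    states=batch_states,
    chunk_size=16,              # frames after encoder_embed; illustrative
    left_context_len=64,        # illustrative
)
states = unstack_states(batch_states)  # inverse of stack_states; len == 4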
@@ -493,7 +263,7 @@ def decode_one_chunk(
     supervision_segments = torch.stack(
         (
             # supervisions["sequence_idx"],
-            list(map(lambda x: x.cut_id, decode_streams)),
+            torch.tensor([index for index, _ in enumerate(decode_streams)]),
             torch.div(
                 0,
                 params.subsampling_factor,
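This one-line fix matters because k2 expects supervision_segments to be a 2-D int32 tensor whose rows are (sequence_idx, start_frame, num_frames); a list of cut-id strings cannot be stacked into such a tensor. A minimal sketch of the expected layout, with illustrative counts:

# Hypothetical sketch of the supervision_segments layout k2 consumes.
import torch

num_streams = 3  # illustrative number of active decode streams
num_frames = 45  # illustrative frames per chunk, after subsampling

supervision_segments = torch.stack(
    (
        torch.arange(num_streams),                    # sequence_idx: 0..N-1
        torch.zeros(num_streams, dtype=torch.int64),  # start_frame of chunk
        torch.full((num_streams,), num_frames),       # num_frames in chunk
    ),
    dim=1,
).to(torch.int32)
# tensor([[ 0,  0, 45],
#         [ 1,  0, 45],
#         [ 2,  0, 45]], dtype=torch.int32)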