From 7f097204034ae0af5afe43a281f550193cf8de15 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 9 Jun 2022 20:37:16 +0800 Subject: [PATCH] refactor streaming decoding --- .../emformer.py | 131 ++-- .../stream.py | 28 +- .../streaming_decode.py | 694 +++++++++--------- .../train.py | 2 +- 4 files changed, 421 insertions(+), 434 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 956278546..9f2a977e9 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -41,8 +41,8 @@ LOG_EPSILON = math.log(1e-10) def unstack_states( - states, -) -> List[List[List[torch.Tensor]]]: + states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] +) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]: # TODO: modify doc """Unstack the emformer state corresponding to a batch of utterances into a list of states, were the i-th entry is the state from the i-th @@ -50,18 +50,14 @@ def unstack_states( Args: states: - A list-of-list of tensors. ``len(states)`` equals to number of - layers in the emformer. ``states[i]]`` contains the states for - the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape - ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)`` + A list-of-list of tensors. + ``len(states[0])`` and ``len(states[1])`` eqaul to number of layers. """ - past_lens, attn_caches, conv_caches = states - batch_size = past_lens.size(0) + attn_caches, conv_caches = states + batch_size = conv_caches[0].size(0) num_layers = len(attn_caches) - list_past_len = past_lens.tolist() - list_attn_caches = [None] * batch_size for i in range(batch_size): list_attn_caches[i] = [[] for _ in range(num_layers)] @@ -81,14 +77,14 @@ def unstack_states( ans = [None] * batch_size for i in range(batch_size): - ans[i] = [list_past_len[i], list_attn_caches[i], list_conv_caches[i]] + ans[i] = [list_attn_caches[i], list_conv_caches[i]] return ans def stack_states( - state_list, -) -> List[List[torch.Tensor]]: + state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]] +) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]: # TODO: modify doc """Stack list of emformer states that correspond to separate utterances into a single emformer state so that it can be used as an input for @@ -108,18 +104,15 @@ def stack_states( """ batch_size = len(state_list) - past_lens = [states[0] for states in state_list] - past_lens = torch.tensor([past_lens]) - attn_caches = [] - for layer in state_list[0][1]: + for layer in state_list[0][0]: if batch_size > 1: # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s] # noqa attn_caches.append([[s] for s in layer]) else: attn_caches.append([s.unsqueeze(1) for s in layer]) for b, states in enumerate(state_list[1:], 1): - for li, layer in enumerate(states[1]): + for li, layer in enumerate(states[0]): for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: @@ -128,19 +121,19 @@ def stack_states( ) conv_caches = [] - for layer in state_list[0][2]: + for layer in state_list[0][1]: if batch_size > 1: # Note: We will stack conv_caches[layer][] later to get attn_caches[layer] # noqa conv_caches.append([layer]) else: conv_caches.append(layer.unsqueeze(0)) for b, states in enumerate(state_list[1:], 1): - for li, layer in enumerate(states[2]): + for li, layer in enumerate(states[1]): conv_caches[li].append(layer) if b == 
batch_size - 1: conv_caches[li] = torch.stack(conv_caches[li], dim=0) - return [past_lens, attn_caches, conv_caches] + return [attn_caches, conv_caches] class ConvolutionModule(nn.Module): @@ -1489,13 +1482,12 @@ class EmformerEncoder(nn.Module): self, x: torch.Tensor, lengths: torch.Tensor, - states: List[ - torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor] - ], + num_processed_frames: torch.Tensor, + states: Tuple[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]], ) -> Tuple[ torch.Tensor, torch.Tensor, - List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]], + Tuple[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]], ]: """Forward pass for streaming inference. @@ -1526,10 +1518,9 @@ class EmformerEncoder(nn.Module): right_context at the end. - updated states from current chunk's computation. """ - past_lens = states[0] - assert past_lens.shape == (x.size(1),), past_lens.shape + assert num_processed_frames.shape == (x.size(1),) - attn_caches = states[1] + attn_caches = states[0] assert len(attn_caches) == self.num_encoder_layers, len(attn_caches) for i in range(len(attn_caches)): assert attn_caches[i][0].shape == ( @@ -1548,24 +1539,23 @@ class EmformerEncoder(nn.Module): self.d_model, ), attn_caches[i][2].shape - conv_caches = states[2] + conv_caches = states[1] assert len(conv_caches) == self.num_encoder_layers, len(conv_caches) for i in range(len(conv_caches)): assert conv_caches[i].shape == ( x.size(1), self.d_model, - self.cnn_module_kernel, + self.cnn_module_kernel - 1, ), conv_caches[i].shape - assert x.size(0) == self.chunk_length + self.right_context_length, ( - "Per configured chunk_length and right_context_length, " - f"expected size of {self.chunk_length + self.right_context_length} " - f"for dimension 1 of x, but got {x.size(1)}." - ) + # assert x.size(0) == self.chunk_length + self.right_context_length, ( + # "Per configured chunk_length and right_context_length, " + # f"expected size of {self.chunk_length + self.right_context_length} " + # f"for dimension 1 of x, but got {x.size(0)}." 
+ # ) - right_context_start_idx = x.size(0) - self.right_context_length - right_context = x[right_context_start_idx:] - utterance = x[:right_context_start_idx] + right_context = x[-self.right_context_length :] + utterance = x[: -self.right_context_length] output_lengths = torch.clamp(lengths - self.right_context_length, min=0) memory = ( self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) @@ -1574,29 +1564,29 @@ class EmformerEncoder(nn.Module): ) # calcualte padding mask to mask out initial zero caches - chunk_mask = make_pad_mask(output_lengths) - memory_mask = ( - (past_lens // self.chunk_length).view(x.size(1), 1) - <= torch.arange(self.memory_size, device=x.device).expand( - x.size(1), self.memory_size - ) - ).flip(1) - left_context_mask = ( - past_lens.view(x.size(1), 1) - <= torch.arange(self.left_context_length, device=x.device).expand( - x.size(1), self.left_context_length - ) - ).flip(1) - right_context_mask = torch.zeros( - x.size(1), - self.right_context_length, - dtype=torch.bool, - device=x.device, - ) - padding_mask = torch.cat( - [memory_mask, left_context_mask, right_context_mask, chunk_mask], - dim=1, - ) + # chunk_mask = make_pad_mask(output_lengths).to(x.device) + # memory_mask = ( + # (past_lens // self.chunk_length).view(x.size(1), 1) + # <= torch.arange(self.memory_size, device=x.device).expand( + # x.size(1), self.memory_size + # ) + # ).flip(1) + # left_context_mask = ( + # past_lens.view(x.size(1), 1) + # <= torch.arange(self.left_context_length, device=x.device).expand( + # x.size(1), self.left_context_length + # ) + # ).flip(1) + # right_context_mask = torch.zeros( + # x.size(1), + # self.right_context_length, + # dtype=torch.bool, + # device=x.device, + # ) + # padding_mask = torch.cat( + # [memory_mask, left_context_mask, right_context_mask, chunk_mask], + # dim=1, + # ) output = utterance output_attn_caches: List[List[torch.Tensor]] = [] @@ -1612,19 +1602,14 @@ class EmformerEncoder(nn.Module): output, right_context, memory, - padding_mask=padding_mask, + # padding_mask=padding_mask, attn_cache=attn_caches[layer_idx], conv_cache=conv_caches[layer_idx], ) output_attn_caches.append(output_attn_cache) output_conv_caches.append(output_conv_cache) - output_past_lens = past_lens + output_lengths - output_states = [ - output_past_lens, - output_attn_caches, - output_conv_caches, - ] + output_states = [output_attn_caches, output_conv_caches] return output, output_lengths, output_states @@ -1738,6 +1723,7 @@ class Emformer(EncoderInterface): self, x: torch.Tensor, x_lens: torch.Tensor, + num_processed_frames: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]: """Forward pass for streaming inference. @@ -1770,16 +1756,17 @@ class Emformer(EncoderInterface): - updated states from current chunk's computation. """ x = self.encoder_embed(x) + # drop the first and last frames + x = x[:, 1:-1, :] x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # Caution: We assume the subsampling factor is 4! 
- with warnings.catch_warnings(): - warnings.simplefilter("ignore") - x_lens = ((x_lens - 1) // 2 - 1) // 2 + x_lens = (((x_lens - 1) >> 1) - 1) >> 1 + x_lens -= 2 assert x.size(0) == x_lens.max().item() output, output_lengths, output_states = self.encoder.infer( - x, x_lens, states + x, x_lens, num_processed_frames, states ) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index b7293cac6..6c7c52df4 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -18,7 +18,7 @@ import math from typing import List, Optional, Tuple import torch -from beam_search import HypothesisList +from beam_search import Hypothesis, HypothesisList from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature from icefall.utils import AttributeDict @@ -48,6 +48,7 @@ class Stream(object): self, params: AttributeDict, device: torch.device = torch.device("cpu"), + LOG_EPS: float = math.log(1e-10), ) -> None: """ Args: @@ -57,11 +58,14 @@ class Stream(object): The device to run this stream. """ self.device = device + self.LOG_EPS = LOG_EPS # Containing attention caches and convolution caches - self.states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = None + self.states: Optional[ + Tuple[List[List[torch.Tensor]], List[torch.Tensor]] + ] = None # Initailize zero states. - self.init_states() + self.init_states(params) # It use different attributes for different decoding methods. self.context_size = params.context_size @@ -70,6 +74,12 @@ class Stream(object): self.hyp = [params.blank_id] * params.context_size elif params.decoding_method == "modified_beam_search": self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -77,7 +87,7 @@ class Stream(object): self.ground_truth: str = "" - self.feature: torch.Tensor = None + self.feature: Optional[torch.Tensor] = None # Make sure all feature frames can be used. # Add 2 here since we will drop the first and last after subsampling. self.chunk_length = params.chunk_length @@ -91,14 +101,14 @@ class Stream(object): self._done = False def set_feature(self, feature: torch.Tensor) -> None: - assert feature.dim == 2, feature.dim + assert feature.dim() == 2, feature.dim() self.num_frames = feature.size(0) # tail padding self.feature = torch.nn.functional.pad( feature, (0, 0, 0, self.pad_length), mode="constant", - value=math.log(1e-10), + value=self.LOG_EPS, ) def set_ground_truth(self, ground_truth: str) -> None: @@ -140,9 +150,11 @@ class Stream(object): ) ret_length = update_length + self.pad_length - ret_feature = self.feature[:ret_length] + ret_feature = self.feature[ + self.num_processed_frames : self.num_processed_frames + ret_length + ] # Cut off used frames. 
- self.feature = self.feature[update_length:] + # self.feature = self.feature[update_length:] self.num_processed_frames += update_length if self.num_processed_frames >= self.num_frames: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index 80c54d384..9234677a1 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -18,9 +18,10 @@ import argparse import logging +import math import warnings from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import k2 from lhotse import CutSet @@ -31,15 +32,24 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import Hypothesis, HypothesisList, get_hyps_shape from emformer import LOG_EPSILON, stack_states, unstack_states -from streaming_feature_extractor import Stream +from kaldifeat import Fbank, FbankOptions +from stream import Stream +from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, + average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) -from icefall.utils import AttributeDict, setup_logger +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) def get_parser(): @@ -55,6 +65,16 @@ def get_parser(): "Note: Epoch counts from 0.", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, @@ -65,14 +85,14 @@ def get_parser(): ) parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -172,52 +192,39 @@ def get_parser(): def greedy_search( model: nn.Module, - streams: List[Stream], encoder_out: torch.Tensor, - sp: spm.SentencePieceProcessor, -): - """ - Args: - model: - The RNN-T model. - streams: - A list of stream objects. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - sp: - The BPE model. 
- """ + streams: List[Stream], +) -> List[List[int]]: + assert len(streams) == encoder_out.size(0) assert encoder_out.ndim == 3 blank_id = model.decoder.blank_id context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device T = encoder_out.size(1) - if streams[0].decoder_out is None: - for stream in streams: - stream.hyp = [blank_id] * context_size - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) - # decoder_out is of shape (N, decoder_out_dim) - else: - decoder_out = torch.stack( - [stream.decoder_out for stream in streams], - dim=0, - ) + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # logging.info(f"decoder_out shape : {decoder_out.shape}") for t in range(T): - current_encoder_out = encoder_out[:, t] - # current_encoder_out's shape: (batch_size, encoder_out_dim) + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) assert logits.ndim == 2, logits.shape y = logits.argmax(dim=1).tolist() @@ -236,227 +243,64 @@ def greedy_search( decoder_out = model.decoder( decoder_input, need_pad=False, - ).squeeze(1) - - for k, stream in enumerate(streams): - result = sp.decode(stream.decoding_result()) - logging.info(f"Partial result {k}:\n{result}") - - decoder_out_list = decoder_out.unbind(dim=0) - for i, d in enumerate(decoder_out_list): - streams[i].decoder_out = d - - -def modified_beam_search( - model: nn.Module, - streams: List[Stream], - encoder_out: torch.Tensor, - sp: spm.SentencePieceProcessor, - beam: int = 4, -): - """ - Args: - model: - The RNN-T model. - streams: - A list of stream objects. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - sp: - The BPE model. - beam: - Number of active paths during the beam search. 
- """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - batch_size = len(streams) - T = encoder_out.size(1) - - for stream in streams: - if len(stream.hyps) == 0: - stream.hyps.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) ) - B = [stream.hyps for stream in streams] - for t in range(T): - current_encoder_out = encoder_out[:, t] - # current_encoder_out's shape: (batch_size, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) - # decoder_out is of shape (num_hyps, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out) - # logits is of shape (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - streams[i].hyps = B[i] - result = sp.decode(streams[i].decoding_result()) - logging.info(f"Partial result {i}:\n{result}") + decoder_out = model.joiner.decoder_proj(decoder_out) -def build_batch( - decode_steams: List[Stream], - chunk_length: int, - segment_length: int, -) -> Tuple[ - Optional[torch.Tensor], - Optional[torch.tensor], - Optional[List[Stream]], -]: - """ - Args: - chunk_length: - Number of frames for each chunk. It equals to - ``segment_length + right_context_length``. - segment_length - Number of frames for each segment. - Returns: - Return a tuple containing: - - features, a 3-D tensor of shape ``(num_active_streams, T, C)`` - - active_streams, a list of active streams. We say a stream is - active when it has enough feature frames to be fed into the - encoder model. 
- """ - feature_list = [] - length_list = [] - stream_list = [] - for stream in decode_steams: - if len(stream.feature_frames) >= chunk_length: - # this_chunk is a list of tensors, each of which - # has a shape (1, feature_dim) - chunk = stream.feature_frames[:chunk_length] - stream.feature_frames = stream.feature_frames[segment_length:] - features = torch.cat(chunk, dim=0) - feature_list.append(features) - length_list.append(chunk_length) - stream_list.append(stream) - elif stream.done and len(stream.feature_frames) > 0: - chunk = stream.feature_frames[:chunk_length] - stream.feature_frames = [] - features = torch.cat(chunk, dim=0) - length_list.append(features.size(0)) - features = torch.nn.functional.pad( - features, - (0, 0, 0, chunk_length - features.size(0)), - mode="constant", - value=LOG_EPSILON, - ) - feature_list.append(features) - stream_list.append(stream) - - if len(feature_list) == 0: - return None, None, None - - features = torch.stack(feature_list, dim=0) - lengths = torch.cat(length_list) - return features, lengths, stream_list - - -def process_features( +def decode_one_chunk( model: nn.Module, - features: torch.Tensor, - feature_lens: torch.Tensor, streams: List[Stream], params: AttributeDict, sp: spm.SentencePieceProcessor, -) -> None: - """Process features for each stream in parallel. +) -> List[int]: + device = next(model.parameters()).device - Args: - model: - The RNN-T model. - features: - A 3-D tensor of shape (N, T, C). - streams: - A list of streams of size (N,). - params: - It is the return value of :func:`get_params`. - sp: - The BPE model. - """ - assert features.ndim == 3 - assert features.size(0) == len(streams) - assert feature_lens.size(0) == len(streams) + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] - device = model.device - features = features.to(device) + for stream in streams: + feature, feature_len = stream.get_feature_chunk() + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + num_processed_frames_list.append(stream.num_processed_frames) - state_list = [stream.states for stream in streams] + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + # print(features.shape) + # stack states of all streams states = stack_states(state_list) encoder_out, encoder_out_lens, states = model.encoder.infer( - features, - feature_lens, - states, + x=features, + x_lens=feature_lens, + states=states, + num_processed_frames=num_processed_frames, ) + encoder_out = model.joiner.encoder_proj(encoder_out) + # update cached states of each stream state_list = unstack_states(states) for i, s in enumerate(state_list): streams[i].states = s @@ -466,26 +310,47 @@ def process_features( model=model, streams=streams, encoder_out=encoder_out, - sp=sp, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - 
streams=streams, - encoder_out=encoder_out, - sp=sp, - beam=params.beam_size, ) + # elif params.decoding_method == "modified_beam_search": + # modified_beam_search( + # model=model, + # streams=streams, + # encoder_out=encoder_out, + # sp=sp, + # beam=params.beam_size, + # ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + def decode_dataset( - params: AttributeDict, cuts: CutSet, model: nn.Module, + params: AttributeDict, sp: spm.SentencePieceProcessor, ): """Decode dataset. @@ -493,72 +358,126 @@ def decode_dataset( """ device = next(model.parameters()).device - # number of frames before subsampling - segment_length = model.encoder.segment_length - right_context_length = model.encoder.right_context_length - # 5 = 3 + 2 - # 1) add 3 here since the subsampling method is using - # ((len - 1) // 2 - 1) // 2) - # 2) add 2 here we will drop first and last frame after subsampling - chunk_length = (segment_length + 5) + right_context_length + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 300 decode_results = [] streams = [] for num, cut in enumerate(cuts): + # Each utterance has a Stream. 
+ stream = Stream(params=params, device=device, LOG_EPS=LOG_EPSILON) + audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype - # The trained model is using normalized samples assert audio.max() <= 1, "Should be normalized to [-1, 1])" samples = torch.from_numpy(audio).squeeze(0) + fbank = create_streaming_feature_extractor() + feature = fbank(samples) + stream.set_feature(feature) + stream.set_ground_truth(cut.supervisions[0].text) - # Each uttetance has a Stream - stream = Stream( - params=params, - audio_sample=samples, - ground_truth=cut.supervisions[0].text, - device=device, - ) streams.append(stream) while len(streams) >= params.num_decode_streams: - for stream in streams: - stream.accept_waveform() - - # try to build batch - features, active_streams = build_batch( - chunk_length=chunk_length, - segment_length=segment_length, + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + sp=sp, ) - if features is not None: - process_features( - model=model, - features=features, - streams=active_streams, - params=params, - sp=sp, - ) - new_streams = [] - for stream in streams: - if stream.done: - decode_results.append( - ( - stream.ground_truth.split(), - sp.decode(stream.decoding_result()).split(), - ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), ) - else: - new_streams.append(stream) - del streams - streams = new_streams + ) + print(decode_results[-1]) + del streams[i] + # print("delete", i, len(streams)) + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + sp=sp, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + return {"greedy_search": decode_results} + else: + return {f"beam_size_{params.beam_size}": decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) @ torch.no_grad() -@torch.no_grad() def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) @@ -571,6 +490,32 @@ def main(): # Note: params.decoding_method is currently not used. params.res_dir = params.exp_dir / "streaming" / params.decoding_method + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-length-{params.chunk_length}" + params.suffix += f"-left-context-length-{params.left_context_length}" + params.suffix += f"-right-context-length-{params.right_context_length}" + params.suffix += f"-memory-size-{params.memory_size}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-streaming-decode") logging.info("Decoding started") @@ -595,24 +540,83 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + 
model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) - model.to(device) model.eval() model.device = device @@ -622,42 +626,26 @@ def main(): librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() - batch_size = 3 + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] - ground_truth = [] - batched_samples = [] - for num, cut in enumerate(test_clean_cuts): - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + ) - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) - samples = torch.from_numpy(audio).squeeze(0) - - # batched_samples.append(samples) - # ground_truth.append(cut.supervisions[0].text) - - if len(batched_samples) >= batch_size: - decoded_results = decode_batch( - batched_samples=batched_samples, - model=model, - params=params, - sp=sp, - ) - s = "\n" - for i, (hyp, ref) in enumerate(zip(decoded_results, ground_truth)): - s += f"hyp {i}:\n{hyp}\n" - s += f"ref {i}:\n{ref}\n\n" - logging.info(s) - batched_samples = [] - ground_truth = [] - # break after processing the first batch for test purposes - break + logging.info("Done!") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index 6d0b14b14..507f19c1b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ 
b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py
@@ -449,7 +449,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         cnn_module_kernel=params.cnn_module_kernel,
         left_context_length=params.left_context_length,
         right_context_length=params.right_context_length,
-        max_memory_size=params.memory_size,
+        memory_size=params.memory_size,
     )
     return encoder
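
For reviewers, a minimal round-trip sketch of the refactored streaming state handling: the `past_lens` tensor is gone, the batched state is now just `(attn_caches, conv_caches)`, and the frame counter travels separately as `num_processed_frames` into `encoder.infer()`. The toy sizes below are placeholders, the per-stream cache shapes reflect my reading of the asserts in `EmformerEncoder.infer` rather than anything the patch states explicitly, and the script is assumed to be run from `egs/librispeech/ASR/conv_emformer_transducer_stateless/` so that `emformer.py` is importable.

```python
import torch

# Assumes the working directory is the conv_emformer_transducer_stateless recipe dir.
from emformer import stack_states, unstack_states

# Toy sizes; the real values come from the encoder configuration.
num_layers = 2
d_model = 8
memory_size = 4
left_context_length = 6
cnn_module_kernel = 5
num_streams = 3


def make_stream_state():
    # Per-stream attention cache: for every layer a triple
    # [memory, left-context key, left-context value], each of shape (length, d_model).
    attn_caches = [
        [
            torch.zeros(memory_size, d_model),
            torch.zeros(left_context_length, d_model),
            torch.zeros(left_context_length, d_model),
        ]
        for _ in range(num_layers)
    ]
    # Per-stream convolution cache: one (d_model, cnn_module_kernel - 1)
    # tensor per layer.
    conv_caches = [
        torch.zeros(d_model, cnn_module_kernel - 1) for _ in range(num_layers)
    ]
    return [attn_caches, conv_caches]


# Batch the states of several independent streams before calling
# model.encoder.infer(..., states=states, num_processed_frames=...).
state_list = [make_stream_state() for _ in range(num_streams)]
states = stack_states(state_list)

# Stacking adds the batch dimension: attention caches become
# (length, num_streams, d_model) and convolution caches become
# (num_streams, d_model, cnn_module_kernel - 1).
print(states[0][0][0].shape, states[1][0].shape)

# unstack_states() reverses the batching so each Stream keeps its own
# updated caches after the chunk has been processed.
per_stream = unstack_states(states)
assert len(per_stream) == num_streams
```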
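And a small helper illustrating the length bookkeeping behind the `x_lens -= 2` adjustment in `Emformer.infer` and the tail padding in `decode_one_chunk`: the front end subsamples by 4 via two `(n - 1) // 2` stages, and the first and last subsampled frames are then dropped. The helper name `num_output_frames` is mine, not part of the patch.

```python
def num_output_frames(num_input_frames: int) -> int:
    """Number of encoder frames produced from a chunk of feature frames,
    following the bookkeeping in Emformer.infer():
        x_lens = (((x_lens - 1) >> 1) - 1) >> 1   # two 2x subsampling stages
        x_lens -= 2                               # first and last frames dropped
    """
    return ((((num_input_frames - 1) >> 1) - 1) >> 1) - 2


# A 45-frame feature chunk survives as 8 encoder frames; anything shorter
# than 15 frames would yield nothing after subsampling and frame dropping,
# which is the motivation for padding short tail chunks in decode_one_chunk().
print(num_output_frames(45))  # 8
print(num_output_frames(15))  # 1
```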