From 026fb22076d21851ff905b674de3ef9caf08db36 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Fri, 10 Jun 2022 17:17:03 +0800
Subject: [PATCH] fix bugs, no grad, and num_processed_frames

---
 .../emformer.py         | 52 ++++++++++---------
 .../streaming_decode.py | 23 ++++----
 2 files changed, 40 insertions(+), 35 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 9f2a977e9..00e8a2489 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -1564,29 +1564,31 @@ class EmformerEncoder(nn.Module):
         )
 
         # calculate padding mask to mask out initial zero caches
-        # chunk_mask = make_pad_mask(output_lengths).to(x.device)
-        # memory_mask = (
-        #     (past_lens // self.chunk_length).view(x.size(1), 1)
-        #     <= torch.arange(self.memory_size, device=x.device).expand(
-        #         x.size(1), self.memory_size
-        #     )
-        # ).flip(1)
-        # left_context_mask = (
-        #     past_lens.view(x.size(1), 1)
-        #     <= torch.arange(self.left_context_length, device=x.device).expand(
-        #         x.size(1), self.left_context_length
-        #     )
-        # ).flip(1)
-        # right_context_mask = torch.zeros(
-        #     x.size(1),
-        #     self.right_context_length,
-        #     dtype=torch.bool,
-        #     device=x.device,
-        # )
-        # padding_mask = torch.cat(
-        #     [memory_mask, left_context_mask, right_context_mask, chunk_mask],
-        #     dim=1,
-        # )
+        chunk_mask = make_pad_mask(output_lengths).to(x.device)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            memory_mask = (
+                (num_processed_frames // self.chunk_length).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+        left_context_mask = (
+            num_processed_frames.view(x.size(1), 1)
+            <= torch.arange(self.left_context_length, device=x.device).expand(
+                x.size(1), self.left_context_length
+            )
+        ).flip(1)
+        right_context_mask = torch.zeros(
+            x.size(1),
+            self.right_context_length,
+            dtype=torch.bool,
+            device=x.device,
+        )
+        padding_mask = torch.cat(
+            [memory_mask, right_context_mask, left_context_mask, chunk_mask],
+            dim=1,
+        )
 
         output = utterance
         output_attn_caches: List[List[torch.Tensor]] = []
@@ -1602,7 +1604,7 @@ class EmformerEncoder(nn.Module):
                 output,
                 right_context,
                 memory,
-                # padding_mask=padding_mask,
+                padding_mask=padding_mask,
                 attn_cache=attn_caches[layer_idx],
                 conv_cache=conv_caches[layer_idx],
             )
@@ -1765,6 +1767,8 @@ class Emformer(EncoderInterface):
         x_lens -= 2
         assert x.size(0) == x_lens.max().item()
 
+        num_processed_frames = num_processed_frames >> 2
+
         output, output_lengths, output_states = self.encoder.infer(
             x, x_lens, num_processed_frames, states
         )
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index bf22e3f2d..35a909397 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -17,6 +17,7 @@
 # limitations under the License.
 
 import argparse
+import copy
 import logging
 import math
 import warnings
@@ -261,11 +262,14 @@
     num_processed_frames_list = []
     for stream in streams:
+        # Get `stream.num_processed_frames` before calling
+        # `stream.get_feature_chunk()`, since that call updates
+        # `stream.num_processed_frames`.
+        num_processed_frames_list.append(stream.num_processed_frames)
         feature, feature_len = stream.get_feature_chunk()
         feature_list.append(feature)
         feature_len_list.append(feature_len)
         state_list.append(stream.states)
-        num_processed_frames_list.append(stream.num_processed_frames)
 
     features = pad_sequence(
         feature_list, batch_first=True, padding_value=LOG_EPSILON
@@ -288,7 +292,6 @@
         mode="constant",
         value=LOG_EPSILON,
     )
-    # print(features.shape)
 
     # stack states of all streams
     states = stack_states(state_list)
@@ -300,12 +303,6 @@
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
-    # update cached states of each stream
-    state_list = unstack_states(states)
-    assert len(streams) == len(state_list)
-    for i, s in enumerate(state_list):
-        streams[i].states = s
-
     if params.decoding_method == "greedy_search":
         greedy_search(
             model=model,
@@ -325,6 +322,11 @@
             f"Unsupported decoding method: {params.decoding_method}"
         )
 
+    # update cached states of each stream
+    state_list = unstack_states(states)
+    for i, s in enumerate(state_list):
+        streams[i].states = s
+
     finished_streams = [i for i, stream in enumerate(streams) if stream.done]
 
     return finished_streams
@@ -399,9 +401,7 @@
                         sp.decode(streams[i].decoding_result()).split(),
                     )
                 )
-                print(decode_results[-1])
                 del streams[i]
-                # print("delete", i, len(streams))
 
         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")
@@ -470,9 +470,10 @@
     for key, val in test_set_wers:
         s += "{}\t{}{}\n".format(key, val, note)
         note = ""
-        logging.info(s)
+    logging.info(s)
 
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
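
Note on the emformer.py change: at inference time the attention keys/values are laid out as [memory, right context, left context, chunk], so the padding mask must be concatenated in that same order; the previously commented-out code used [memory, left context, right context, chunk]. The mask marks cache entries that are still initial zeros: a stream that has processed `num_processed_frames` subsampled frames has `num_processed_frames // chunk_length` valid memory vectors and at most `num_processed_frames` valid left-context frames, right-aligned in their caches (hence the `flip`). Below is a minimal standalone sketch of that construction; all sizes are toy values chosen for illustration, not the real Emformer configuration:

```python
import torch

# Toy sizes for illustration only; real values come from the Emformer config.
batch_size = 2
memory_size = 4
left_context_length = 8
chunk_length = 8
right_context_length = 2

# Subsampled frames already processed per stream: stream 0 is brand new,
# stream 1 has already seen 3 full chunks (24 frames).
num_processed_frames = torch.tensor([0, 24])

# A stream with k processed chunks has k valid memory vectors; caches are
# right-aligned, hence the flip. True means "masked out" (still a zero cache).
memory_mask = (
    (num_processed_frames // chunk_length).view(batch_size, 1)
    <= torch.arange(memory_size).expand(batch_size, memory_size)
).flip(1)

# Same idea for the cached left-context frames.
left_context_mask = (
    num_processed_frames.view(batch_size, 1)
    <= torch.arange(left_context_length).expand(batch_size, left_context_length)
).flip(1)

# The right context is sliced from the current input, so it is never a cache.
right_context_mask = torch.zeros(
    batch_size, right_context_length, dtype=torch.bool
)

# All current-chunk frames are valid in this toy batch (no padding).
chunk_mask = torch.zeros(batch_size, chunk_length, dtype=torch.bool)

# Order must match the attention key/value layout:
# [memory, right_context, left_context, chunk].
padding_mask = torch.cat(
    [memory_mask, right_context_mask, left_context_mask, chunk_mask], dim=1
)

print(padding_mask.int())
# Stream 0: all memory and left-context slots masked (nothing processed yet);
# stream 1: one memory slot masked (3 of 4 filled), no left context masked.
```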
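Note on the streaming_decode.py change: `stream.get_feature_chunk()` advances `stream.num_processed_frames`, so the counter must be snapshotted before the call; reading it afterwards over-counts by one chunk, which makes the encoder unmask cache entries that are still zeros. (Relatedly, `num_processed_frames >> 2` in emformer.py converts the raw feature-frame count to the encoder's frame rate, since the convolutional front end subsamples by a factor of 4.) A toy sketch of the snapshot-then-fetch pattern, using a hypothetical `Stream` class standing in for the real decode stream:

```python
# Hypothetical stand-in for the decode stream used in streaming_decode.py;
# only the counter-updating behaviour of get_feature_chunk() is modelled.
class Stream:
    def __init__(self, features, chunk_size):
        self.features = features
        self.chunk_size = chunk_size
        self.num_processed_frames = 0

    def get_feature_chunk(self):
        # Returns the next chunk and advances the frame counter, which is
        # exactly why callers must read num_processed_frames beforehand.
        start = self.num_processed_frames
        chunk = self.features[start : start + self.chunk_size]
        self.num_processed_frames += len(chunk)
        return chunk, len(chunk)


stream = Stream(features=list(range(20)), chunk_size=8)

# Correct order: snapshot the counter, then fetch the chunk.
processed = stream.num_processed_frames
chunk, chunk_len = stream.get_feature_chunk()
assert processed == 0 and chunk_len == 8
assert stream.num_processed_frames == 8  # updated by get_feature_chunk()
```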