diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index b36c02df2..91f8d3380 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import copy
 import math
 import warnings
@@ -234,7 +235,7 @@ class Conformer(EncoderInterface):
 
         if not simulate_streaming:
             assert (
-                decode_states is not None
+                states is not None
             ), "Require cache when sending data in streaming mode"
 
             assert (
@@ -423,7 +424,7 @@ class ConformerEncoderLayer(nn.Module):
             # src: [chunk_size, N, F] e.g. [8, 41, 512]
             key = torch.cat([states[0, ...], src], dim=0)
             val = key
-            states[0, ...] = key[-left_context, ...]
+            states[0, ...] = key[-left_context:, ...]
         else:
             assert left_context == 0
 
@@ -441,14 +442,15 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
+        residual = src
         if not self.training and states is not None:
             src = torch.cat([states[1, ...], src], dim=0)
-            states[1, ...] = src[-left_context, ...]
+            states[1, ...] = src[-left_context:, ...]
 
         conv = self.conv_module(src)
-        conv = conv[-src.size(0) :, :, :]  # noqa: E203
+        conv = conv[-residual.size(0) :, :, :]  # noqa: E203
 
-        src = src + self.dropout(conv)
+        src = residual + self.dropout(conv)
 
         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index cbc2baa8a..0acab77c3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -70,6 +70,7 @@ Usage:
 
 import argparse
 import logging
+import math
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -101,6 +102,7 @@ from icefall.utils import (
     write_error_stats,
 )
 
+LOG_EPS = math.log(1e-10)
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -324,6 +326,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
+    feature_lens += params.left_context
+    feature = torch.nn.functional.pad(
+        feature,
+        pad=(0, 0, 0, params.left_context),
+        value=LOG_EPS,
+    )
+
     if params.simulate_streaming:
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,