From 747339a6c163626067f076099dcaf99ff586997a Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 12 Apr 2022 15:54:50 +0800
Subject: [PATCH] Use torch.stack() to replace torch.cat()

---
 .../ASR/pruned_transducer_stateless/joiner.py |  9 +++--
 .../ASR/transducer_emformer/emformer.py       | 17 ++++++--
 .../transducer_emformer/streaming_decode.py   | 40 +++++++++----------
 .../streaming_feature_extractor.py            |  1 +
 4 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
index 7c5a93a86..fbb30e057 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py
@@ -32,13 +32,16 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
 
         logit = encoder_out + decoder_out
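Note: the relaxed assertion lets the same Joiner serve both training,
where the pruned-range tensors are 4-D, and streaming decoding, where
each stream contributes a single 2-D frame. A minimal sketch of the
contract, assuming illustrative sizes (dim=512, vocab_size=500) and a
simple tanh-plus-projection joiner rather than this recipe's exact
module:

    import torch
    import torch.nn as nn

    class Joiner(nn.Module):
        def __init__(self, dim: int, vocab_size: int):
            super().__init__()
            self.output_linear = nn.Linear(dim, vocab_size)

        def forward(self, encoder_out, decoder_out):
            # Same shape checks as the patched forward() above.
            assert encoder_out.ndim == decoder_out.ndim
            assert encoder_out.ndim in (2, 4)
            assert encoder_out.shape == decoder_out.shape
            return self.output_linear(torch.tanh(encoder_out + decoder_out))

    joiner = Joiner(dim=512, vocab_size=500)
    # Training: (N, T, s_range, C) -> (N, T, s_range, vocab_size)
    print(joiner(torch.randn(2, 10, 5, 512), torch.randn(2, 10, 5, 512)).shape)
    # Streaming decoding: (N, C) -> (N, vocab_size)
    print(joiner(torch.randn(2, 512), torch.randn(2, 512)).shape)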
diff --git a/egs/librispeech/ASR/transducer_emformer/emformer.py b/egs/librispeech/ASR/transducer_emformer/emformer.py
index 80849afd6..0029b42af 100644
--- a/egs/librispeech/ASR/transducer_emformer/emformer.py
+++ b/egs/librispeech/ASR/transducer_emformer/emformer.py
@@ -51,8 +51,9 @@ def unstack_states(
     for li, layer in enumerate(states):
         for s in layer:
             s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
             for bi, b in enumerate(ans):
-                b[li].append(s_list[bi].unsqueeze(dim=1))
+                b[li].append(s_list[bi])
 
     return ans
 
@@ -75,15 +76,23 @@ def stack_states(
     See the input argument of :func:`unstack_states` for the meaning
     of the returned tensor.
     """
+    batch_size = len(state_list)
     ans = []
     for layer in state_list[0]:
         # layer is a list of tensors
-        ans.append([s for s in layer])
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: ans[li][si] is a list here; it is stacked below
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])
 
-    for states in state_list[1:]:
+    for b, states in enumerate(state_list[1:], 1):
         for li, layer in enumerate(states):
             for si, s in enumerate(layer):
-                ans[li][si] = torch.cat([ans[li][si], s], dim=1)
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
 
     return ans
 
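Note: the round trip between stack_states() and unstack_states() now
pairs torch.stack(dim=1) with unbind(dim=1), which are exact inverses.
The per-stream tensors are collected in Python lists so that the batch
dimension is materialized by a single stack() at the end instead of a
torch.cat() inside the loop, i.e. one allocation instead of one per
utterance. A small self-contained sketch, with an assumed cache shape
of (length, batch, dim):

    import torch

    length, batch, dim = 4, 3, 8
    state = torch.randn(length, batch, dim)

    # unstack_states(): split along dim=1; the per-stream tensors no
    # longer carry a batch dimension.
    per_stream = state.unbind(dim=1)  # `batch` tensors of shape (length, dim)

    # stack_states(): gather the per-stream tensors and stack once.
    restored = torch.stack(list(per_stream), dim=1)
    assert torch.equal(restored, state)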
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
index 93ca43ff3..bb71310b7 100755
--- a/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_decode.py
@@ -190,11 +190,7 @@ class StreamingAudioSamples(object):
         """
         ans = []
 
-        # Note: Either branch is fine. The purpose is to simulate streaming
-        if False:
-            num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
-        else:
-            num = [1024] * len(self.samples)
+        num = [1024] * len(self.samples)
 
         for i in range(len(self.samples)):
             start = self.cur_indexes[i]
@@ -293,16 +289,17 @@ class StreamList(object):
                 # has a shape (1, feature_dim)
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 feature_list.append(features)
                 stream_list.append(stream)
             elif stream.done and len(stream.feature_frames) > 0:
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = []
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 features = torch.nn.functional.pad(
                     features,
-                    (0, 0, 0, chunk_length - features.size(1)),
+                    (0, 0, 0, chunk_length - features.size(0)),
+                    mode="constant",
                     value=LOG_EPSILON,
                 )
                 feature_list.append(features)
@@ -311,7 +308,7 @@
         if len(feature_list) == 0:
             return None, None
 
-        features = torch.cat(feature_list, dim=0)
+        features = torch.stack(feature_list, dim=0)
 
         return features, stream_list
 
@@ -346,10 +343,10 @@ def greedy_search(
         decoder_out = model.decoder(
             decoder_input,
             need_pad=False,
-        ).unsqueeze(1)
-        # decoder_out is of shape (N, 1, decoder_out_dim)
+        ).squeeze(1)
+        # decoder_out is of shape (N, decoder_out_dim)
     else:
-        decoder_out = torch.cat(
+        decoder_out = torch.stack(
             [stream.decoder_out for stream in streams],
             dim=0,
         )
@@ -358,13 +355,12 @@
 
     T = encoder_out.size(1)
     for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
-        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
 
         logits = model.joiner(current_encoder_out, decoder_out)
-        # logits'shape (batch_size, 1, 1, vocab_size)
+        # logits' shape: (batch_size, vocab_size)
 
-        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
@@ -380,9 +376,9 @@
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(
-                decoder_input, need_pad=False
-            ).unsqueeze(1)
+            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
+                1
+            )
 
     for k, s in enumerate(streams):
         logging.info(
@@ -392,7 +388,7 @@
     decoder_out_list = decoder_out.unbind(dim=0)
 
     for i, d in enumerate(decoder_out_list):
-        streams[i].decoder_out = d.unsqueeze(0)
+        streams[i].decoder_out = d
 
 
 def process_features(
@@ -424,6 +420,10 @@
         fill_value=features.size(1),
         device=device,
     )
+
+    # Caution: There is a limitation here: we assume that if one of
+    # the streams has an empty state, then all other streams also
+    # have empty states.
     if streams[0].states is None:
         states = None
     else:
diff --git a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
index e9f40576f..b20f6502f 100644
--- a/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
+++ b/egs/librispeech/ASR/transducer_emformer/streaming_feature_extractor.py
@@ -60,6 +60,7 @@ class FeatureExtractionStream(object):
 
         # For the RNN-T decoder, it contains the decoder output
         # corresponding to the decoder input self.hyp.ys[-context_size:]
+        # Its shape is (decoder_out_dim,)
        self.decoder_out: Optional[torch.Tensor] = None
 
         # After calling `self.input_finished()`, we set this flag to True
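Note: greedy_search() applies the same pattern to the per-stream
decoder output: each stream now caches a 1-D tensor of shape
(decoder_out_dim,), the batch is formed with torch.stack(dim=0), and
unbind(dim=0) hands the rows back to their streams. A short sketch,
with an assumed decoder_out_dim of 512 and three active streams:

    import torch

    per_stream = [torch.randn(512) for _ in range(3)]

    # Batch for the joiner: stack(dim=0) adds the batch axis that the
    # old code kept around via unsqueeze(0) plus torch.cat(dim=0).
    decoder_out = torch.stack(per_stream, dim=0)  # (3, 512)

    # Return the rows to their streams; unbind(dim=0) is the inverse.
    for cached, row in zip(per_stream, decoder_out.unbind(dim=0)):
        assert torch.equal(cached, row)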