Use torch.stack() to replace torch.cat()

Fangjun Kuang 2022-04-12 15:54:50 +08:00
parent 4cef2728cd
commit 747339a6c1
4 changed files with 40 additions and 27 deletions

View File

@@ -32,13 +32,16 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
         logit = encoder_out + decoder_out
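Under the relaxed assertions the joiner accepts both the 4-D pruned-RNN-T training inputs and the new 2-D per-frame streaming inputs. A minimal, hypothetical sketch of why both cases go through the same code path (only the checks and the addition from the hunk above are reproduced; the repo's Joiner contains more than this):

import torch

def joiner_add(encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
    # Same checks as in the new code above; the element-wise addition works
    # identically for 4-D training tensors and 2-D streaming tensors.
    assert encoder_out.ndim == decoder_out.ndim
    assert encoder_out.ndim in (2, 4)
    assert encoder_out.shape == decoder_out.shape
    return encoder_out + decoder_out

# Training: pruned RNN-T tensors of shape (N, T, s_range, C).
assert joiner_add(torch.randn(2, 10, 5, 512), torch.randn(2, 10, 5, 512)).shape == (2, 10, 5, 512)
# Streaming decoding: one frame per stream, shape (N, C).
assert joiner_add(torch.randn(2, 512), torch.randn(2, 512)).shape == (2, 512)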

View File

@@ -51,8 +51,9 @@ def unstack_states(
     for li, layer in enumerate(states):
         for s in layer:
             s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
             for bi, b in enumerate(ans):
-                b[li].append(s_list[bi].unsqueeze(dim=1))
+                b[li].append(s_list[bi])
     return ans
@@ -75,15 +76,23 @@ def stack_states(
     See the input argument of :func:`unstack_states` for the meaning
     of the returned tensor.
     """
+    batch_size = len(state_list)
     ans = []
     for layer in state_list[0]:
         # layer is a list of tensors
-        ans.append([s for s in layer])
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: We will stack ans[layer][s][] later to get ans[layer][s]
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])

-    for states in state_list[1:]:
+    for b, states in enumerate(state_list[1:], 1):
         for li, layer in enumerate(states):
             for si, s in enumerate(layer):
-                ans[li][si] = torch.cat([ans[li][si], s], dim=1)
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
     return ans
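The two hunks above establish a symmetric convention: stack_states() combines the per-stream cached tensors with torch.stack(dim=1), and unstack_states() splits them again with unbind(dim=1), so no unsqueeze/cat pair is needed. A toy round-trip check (shapes are made up; only the batch dimension at dim=1 matters):

import torch

batch_size = 3
per_stream = [torch.randn(4, 8) for _ in range(batch_size)]  # one cached tensor per stream

stacked = torch.stack(per_stream, dim=1)   # (4, batch_size, 8), as built in stack_states()
assert stacked.shape == (4, 3, 8)

unstacked = stacked.unbind(dim=1)          # back to batch_size tensors of shape (4, 8)
for original, recovered in zip(per_stream, unstacked):
    assert torch.equal(original, recovered)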

View File

@@ -190,11 +190,7 @@ class StreamingAudioSamples(object):
         """
         ans = []
-        # Note: Either branch is fine. The purpose is to simulate streaming
-        if False:
-            num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
-        else:
-            num = [1024] * len(self.samples)
+        num = [1024] * len(self.samples)

         for i in range(len(self.samples)):
             start = self.cur_indexes[i]
@@ -293,16 +289,17 @@ class StreamList(object):
                 # has a shape (1, feature_dim)
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 feature_list.append(features)
                 stream_list.append(stream)
             elif stream.done and len(stream.feature_frames) > 0:
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = []
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 features = torch.nn.functional.pad(
                     features,
-                    (0, 0, 0, chunk_length - features.size(1)),
+                    (0, 0, 0, chunk_length - features.size(0)),
+                    mode="constant",
                     value=LOG_EPSILON,
                 )
                 feature_list.append(features)
@@ -311,7 +308,7 @@ class StreamList(object):
         if len(feature_list) == 0:
             return None, None

-        features = torch.cat(feature_list, dim=0)
+        features = torch.stack(feature_list, dim=0)
         return features, stream_list
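With the two hunks above, each stream now contributes a 2-D (num_frames, feature_dim) chunk: the short final chunk of a finished stream is padded along the frame dimension, and the batch dimension is added once, by torch.stack, instead of by unsqueeze followed by torch.cat. A sketch with made-up sizes (the LOG_EPSILON value below is an assumption, not necessarily the repo's constant):

import math

import torch

LOG_EPSILON = math.log(1e-10)  # placeholder for the repo's padding constant
chunk_length, feature_dim = 8, 80

full_chunk = torch.randn(chunk_length, feature_dim)  # a stream with enough frames
short_chunk = torch.randn(5, feature_dim)            # a finished stream, few frames left

# Pad dim 0 (the frame axis) of the 2-D tensor up to chunk_length, matching
# F.pad(..., (0, 0, 0, chunk_length - features.size(0))) in the diff.
padded = torch.nn.functional.pad(
    short_chunk,
    (0, 0, 0, chunk_length - short_chunk.size(0)),
    mode="constant",
    value=LOG_EPSILON,
)
assert padded.shape == (chunk_length, feature_dim)

# torch.stack supplies the batch dimension that cat + unsqueeze used to provide.
features = torch.stack([full_chunk, padded], dim=0)
assert features.shape == (2, chunk_length, feature_dim)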
@@ -346,10 +343,10 @@ def greedy_search(
         decoder_out = model.decoder(
             decoder_input,
             need_pad=False,
-        ).unsqueeze(1)
-        # decoder_out is of shape (N, 1, decoder_out_dim)
+        ).squeeze(1)
+        # decoder_out is of shape (N, decoder_out_dim)
     else:
-        decoder_out = torch.cat(
+        decoder_out = torch.stack(
             [stream.decoder_out for stream in streams],
             dim=0,
         )
@@ -358,13 +355,12 @@ def greedy_search(
     T = encoder_out.size(1)
     for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
-        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
         logits = model.joiner(current_encoder_out, decoder_out)
-        # logits'shape (batch_size, 1, 1, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        # logits'shape (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
@@ -380,9 +376,9 @@ def greedy_search(
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(
-                decoder_input, need_pad=False
-            ).unsqueeze(1)
+            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
+                1
+            )

     for k, s in enumerate(streams):
         logging.info(
@@ -392,7 +388,7 @@ def greedy_search(
     decoder_out_list = decoder_out.unbind(dim=0)

     for i, d in enumerate(decoder_out_list):
-        streams[i].decoder_out = d.unsqueeze(0)
+        streams[i].decoder_out = d


 def process_features(
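Taken together, the greedy_search hunks above drop the singleton dimensions: per-stream decoder outputs of shape (decoder_out_dim,) are stacked into (N, decoder_out_dim), each encoder frame is taken as (N, encoder_out_dim), the joiner returns (N, vocab_size) directly, and unbind(dim=0) hands a (decoder_out_dim,) tensor back to each stream. A toy walkthrough (all sizes are made up; the joiner is modeled here as addition plus a stand-in linear projection, which is an assumption about the real model):

import torch

N, T, C, vocab_size = 3, 4, 16, 10
encoder_out = torch.randn(N, T, C)                           # (N, T, encoder_out_dim)
per_stream_decoder_out = [torch.randn(C) for _ in range(N)]  # (decoder_out_dim,) per stream
project = torch.nn.Linear(C, vocab_size)                     # stand-in for the real joiner output

decoder_out = torch.stack(per_stream_decoder_out, dim=0)     # (N, C); torch.cat would need unsqueeze
for t in range(T):
    current_encoder_out = encoder_out[:, t]                  # (N, C); no unsqueeze/squeeze dance
    logits = project(current_encoder_out + decoder_out)      # (N, vocab_size)
    assert logits.ndim == 2, logits.shape
    y = logits.argmax(dim=1).tolist()                        # one token id per stream

# After decoding, unbind(dim=0) gives each stream back a (decoder_out_dim,) tensor.
for d in decoder_out.unbind(dim=0):
    assert d.shape == (C,)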
@@ -424,6 +420,10 @@ def process_features(
         fill_value=features.size(1),
         device=device,
     )
+
+    # Caution: It has a limitation as it assumes that
+    # if one of the stream has an empty state, then all other
+    # streams also have empty states.
     if streams[0].states is None:
         states = None
     else:
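The caution added above only documents the assumption; a hypothetical guard (not part of this commit) could enforce it explicitly. DummyStream below is a stand-in so the check runs on its own:

from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyStream:
    # Stand-in for the real stream objects; only the states attribute matters here.
    states: Optional[list] = None


streams = [DummyStream(), DummyStream()]  # expected: all None or all non-None
assert all(
    (s.states is None) == (streams[0].states is None) for s in streams
), "All streams must have states, or all must have empty states"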

View File

@@ -60,6 +60,7 @@ class FeatureExtractionStream(object):
         # For the RNN-T decoder, it contains the decoder output
         # corresponding to the decoder input self.hyp.ys[-context_size:]
+        # Its shape is (decoder_out_dim,)
         self.decoder_out: Optional[torch.Tensor] = None

         # After calling `self.input_finished()`, we set this flag to True