Use torch.stack() to replace torch.cat()

Fangjun Kuang 2022-04-12 15:54:50 +08:00
parent 4cef2728cd
commit 747339a6c1
4 changed files with 40 additions and 27 deletions

@@ -32,13 +32,16 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
         logit = encoder_out + decoder_out
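The relaxed assertions let the same forward() serve both the 4-D pruned-training path and the 2-D streaming path, since element-wise addition only needs matching shapes. A minimal sketch (tensor names and sizes are made up for illustration):

import torch

N, T, s_range, C = 2, 5, 3, 8

# Training: pruned RNN-T feeds 4-D tensors.
enc = torch.randn(N, T, s_range, C)
dec = torch.randn(N, T, s_range, C)
assert enc.ndim == dec.ndim and enc.ndim in (2, 4)
logit = enc + dec  # (N, T, s_range, C)

# Streaming decoding: one frame per stream, 2-D tensors.
enc = torch.randn(N, C)
dec = torch.randn(N, C)
assert enc.ndim == dec.ndim and enc.ndim in (2, 4)
logit = enc + dec  # (N, C)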

@@ -51,8 +51,9 @@ def unstack_states(
     for li, layer in enumerate(states):
         for s in layer:
             s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
             for bi, b in enumerate(ans):
-                b[li].append(s_list[bi].unsqueeze(dim=1))
+                b[li].append(s_list[bi])
     return ans
@@ -75,15 +76,23 @@ def stack_states(
     See the input argument of :func:`unstack_states` for the meaning
     of the returned tensor.
     """
+    batch_size = len(state_list)
     ans = []
     for layer in state_list[0]:
         # layer is a list of tensors
-        ans.append([s for s in layer])
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: We will stack ans[layer][s][] later to get ans[layer][s]
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])

-    for states in state_list[1:]:
+    for b, states in enumerate(state_list[1:], 1):
         for li, layer in enumerate(states):
             for si, s in enumerate(layer):
-                ans[li][si] = torch.cat([ans[li][si], s], dim=1)
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
     return ans
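The pair of functions now relies on stack(dim=1) being the exact inverse of unbind(dim=1), and collects Python lists first so that torch.stack runs once per state tensor instead of torch.cat reallocating the batch on every stream. A minimal round-trip sketch with made-up state shapes:

import torch

# Hypothetical per-layer state with the batch on dim=1,
# e.g. (cache_len, N, d_model) for N = 3 streams.
batched = torch.randn(4, 3, 8)

# unstack_states: one (4, 8) tensor per stream, no unsqueeze needed.
per_stream = batched.unbind(dim=1)

# stack_states: torch.stack reinserts the batch dim in one call.
assert torch.equal(torch.stack(per_stream, dim=1), batched)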

@@ -190,11 +190,7 @@ class StreamingAudioSamples(object):
         """
         ans = []
-        # Note: Either branch is fine. The purpose is to simulate streaming
-        if False:
-            num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
-        else:
-            num = [1024] * len(self.samples)
+        num = [1024] * len(self.samples)
         for i in range(len(self.samples)):
             start = self.cur_indexes[i]
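For context, a sketch of how such a method can simulate streaming by handing out one fixed-size slice per call; the buffer handling here is illustrative, and only the num, start, and cur_indexes names come from the hunk above:

import torch

samples = [torch.arange(5000.0), torch.arange(3000.0)]  # two fake streams
cur_indexes = [0, 0]

num = [1024] * len(samples)  # fixed chunk size per stream
ans = []
for i in range(len(samples)):
    start = cur_indexes[i]
    end = start + num[i]
    ans.append(samples[i][start:end])  # may be shorter near the end
    cur_indexes[i] = end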
@@ -293,16 +289,17 @@ class StreamList(object):
                 # has a shape (1, feature_dim)
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 feature_list.append(features)
                 stream_list.append(stream)
             elif stream.done and len(stream.feature_frames) > 0:
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = []
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 features = torch.nn.functional.pad(
                     features,
-                    (0, 0, 0, chunk_length - features.size(1)),
+                    (0, 0, 0, chunk_length - features.size(0)),
                     mode="constant",
                     value=LOG_EPSILON,
                 )
                 feature_list.append(features)
@@ -311,7 +308,7 @@ class StreamList(object):
         if len(feature_list) == 0:
             return None, None
-        features = torch.cat(feature_list, dim=0)
+        features = torch.stack(feature_list, dim=0)
         return features, stream_list
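The batching convention after this change: each stream keeps a 2-D (chunk_length, feature_dim) chunk, the time axis is padded via its dim-0 size, and torch.stack adds the batch axis once at the end. A sketch with assumed sizes:

import torch

chunk_length, feature_dim = 10, 80
LOG_EPSILON = -23.0  # placeholder for the log of a tiny value

# A finished stream with only 6 frames left: pad time (dim 0) up to chunk_length.
features = torch.randn(6, feature_dim)
features = torch.nn.functional.pad(
    features,
    (0, 0, 0, chunk_length - features.size(0)),
    mode="constant",
    value=LOG_EPSILON,
)

feature_list = [features, torch.randn(chunk_length, feature_dim)]

# (num_streams, chunk_length, feature_dim); previously each chunk carried a
# leading 1 and torch.cat(dim=0) glued them together.
batch = torch.stack(feature_list, dim=0)
assert batch.shape == (2, chunk_length, feature_dim)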
@@ -346,10 +343,10 @@ def greedy_search(
         decoder_out = model.decoder(
             decoder_input,
             need_pad=False,
-        ).unsqueeze(1)
-        # decoder_out is of shape (N, 1, decoder_out_dim)
+        ).squeeze(1)
+        # decoder_out is of shape (N, decoder_out_dim)
     else:
-        decoder_out = torch.cat(
+        decoder_out = torch.stack(
             [stream.decoder_out for stream in streams],
             dim=0,
         )
@@ -358,13 +355,12 @@ def greedy_search(
     T = encoder_out.size(1)
     for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
-        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
         logits = model.joiner(current_encoder_out, decoder_out)
-        # logits' shape: (batch_size, 1, 1, vocab_size)
-        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        # logits' shape: (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
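The per-frame slice is now plain indexing; a sketch of the shape change (sizes assumed):

import torch

batch_size, T, encoder_out_dim = 3, 6, 512
encoder_out = torch.randn(batch_size, T, encoder_out_dim)

t = 0
# Before: encoder_out[:, t : t + 1, :].unsqueeze(2) -> (batch_size, 1, 1, C),
# which forced two squeeze(1) calls after the joiner.
# After: indexing drops the time axis directly.
current_encoder_out = encoder_out[:, t]
assert current_encoder_out.shape == (batch_size, encoder_out_dim)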
@@ -380,9 +376,9 @@ def greedy_search(
                 device=device,
                 dtype=torch.int64,
             )
-            decoder_out = model.decoder(
-                decoder_input, need_pad=False
-            ).unsqueeze(1)
+            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
+                1
+            )

     for k, s in enumerate(streams):
         logging.info(
@@ -392,7 +388,7 @@ def greedy_search(
     decoder_out_list = decoder_out.unbind(dim=0)
     for i, d in enumerate(decoder_out_list):
-        streams[i].decoder_out = d.unsqueeze(0)
+        streams[i].decoder_out = d


 def process_features(
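Per-stream decoder output is now cached as a 1-D (decoder_out_dim,) tensor (matching the comment added in the last file below), so stack(dim=0)/unbind(dim=0) form another exact round trip. A sketch with an assumed dimension:

import torch

N, decoder_out_dim = 3, 512
decoder_out = torch.randn(N, decoder_out_dim)  # batched for this chunk

# Cache one (decoder_out_dim,) tensor per stream ...
cached = decoder_out.unbind(dim=0)

# ... and rebuild the batch for the next chunk without any unsqueeze.
assert torch.equal(torch.stack(cached, dim=0), decoder_out)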
@@ -424,6 +420,10 @@ def process_features(
         fill_value=features.size(1),
         device=device,
     )
+    # Caution: There is a limitation here. We assume that if one of the
+    # streams has an empty state, then all other streams also have
+    # empty states.
     if streams[0].states is None:
         states = None
     else:
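A sketch of the pattern the caution refers to; the else branch here is an assumption about the elided code, written in terms of the stack_states() from the second file above:

    if streams[0].states is None:
        # Assumed: every stream is at its first chunk, so no state to batch.
        states = None
    else:
        # Assumed: batch the per-stream states for the streaming encoder.
        states = stack_states([stream.states for stream in streams])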

@@ -60,6 +60,7 @@ class FeatureExtractionStream(object):
         # For the RNN-T decoder, it contains the decoder output
         # corresponding to the decoder input self.hyp.ys[-context_size:]
+        # Its shape is (decoder_out_dim,)
         self.decoder_out: Optional[torch.Tensor] = None

         # After calling `self.input_finished()`, we set this flag to True