mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 22:54:18 +00:00
Use torch.stack() to replace torch.cat()
This commit is contained in:
parent
4cef2728cd
commit
747339a6c1
@ -32,13 +32,16 @@ class Joiner(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
Output from the encoder. Its shape is (N, T, s_range, C) for
|
||||
training and (N, C) for streaming decoding.
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
Output from the decoder. Its shape is (N, T, s_range, C) for
|
||||
training and (N, C) for streaming decoding.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
|
@ -51,8 +51,9 @@ def unstack_states(
|
||||
for li, layer in enumerate(states):
|
||||
for s in layer:
|
||||
s_list = s.unbind(dim=1)
|
||||
# We will use stack(dim=1) later in stack_states()
|
||||
for bi, b in enumerate(ans):
|
||||
b[li].append(s_list[bi].unsqueeze(dim=1))
|
||||
b[li].append(s_list[bi])
|
||||
return ans
|
||||
|
||||
|
||||
@ -75,15 +76,23 @@ def stack_states(
|
||||
See the input argument of :func:`unstack_states` for the meaning
|
||||
of the returned tensor.
|
||||
"""
|
||||
batch_size = len(state_list)
|
||||
ans = []
|
||||
for layer in state_list[0]:
|
||||
# layer is a list of tensors
|
||||
ans.append([s for s in layer])
|
||||
if batch_size > 1:
|
||||
ans.append([[s] for s in layer])
|
||||
# Note: We will stack ans[layer][s][] later to get ans[layer][s]
|
||||
else:
|
||||
ans.append([s.unsqueeze(1) for s in layer])
|
||||
|
||||
for states in state_list[1:]:
|
||||
for b, states in enumerate(state_list[1:], 1):
|
||||
for li, layer in enumerate(states):
|
||||
for si, s in enumerate(layer):
|
||||
ans[li][si] = torch.cat([ans[li][si], s], dim=1)
|
||||
ans[li][si].append(s)
|
||||
if b == batch_size - 1:
|
||||
ans[li][si] = torch.stack(ans[li][si], dim=1)
|
||||
# We will use unbind(dim=1) later in unstack_states()
|
||||
return ans
|
||||
|
||||
|
||||
|
@ -190,11 +190,7 @@ class StreamingAudioSamples(object):
|
||||
"""
|
||||
ans = []
|
||||
|
||||
# Note: Either branch is fine. The purpose is to simulate streaming
|
||||
if False:
|
||||
num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
|
||||
else:
|
||||
num = [1024] * len(self.samples)
|
||||
num = [1024] * len(self.samples)
|
||||
|
||||
for i in range(len(self.samples)):
|
||||
start = self.cur_indexes[i]
|
||||
@ -293,16 +289,17 @@ class StreamList(object):
|
||||
# has a shape (1, feature_dim)
|
||||
chunk = stream.feature_frames[:chunk_length]
|
||||
stream.feature_frames = stream.feature_frames[segment_length:]
|
||||
features = torch.cat(chunk, dim=0).unsqueeze(0)
|
||||
features = torch.cat(chunk, dim=0)
|
||||
feature_list.append(features)
|
||||
stream_list.append(stream)
|
||||
elif stream.done and len(stream.feature_frames) > 0:
|
||||
chunk = stream.feature_frames[:chunk_length]
|
||||
stream.feature_frames = []
|
||||
features = torch.cat(chunk, dim=0).unsqueeze(0)
|
||||
features = torch.cat(chunk, dim=0)
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, chunk_length - features.size(1)),
|
||||
(0, 0, 0, chunk_length - features.size(0)),
|
||||
mode="constant",
|
||||
value=LOG_EPSILON,
|
||||
)
|
||||
feature_list.append(features)
|
||||
@ -311,7 +308,7 @@ class StreamList(object):
|
||||
if len(feature_list) == 0:
|
||||
return None, None
|
||||
|
||||
features = torch.cat(feature_list, dim=0)
|
||||
features = torch.stack(feature_list, dim=0)
|
||||
return features, stream_list
|
||||
|
||||
|
||||
@ -346,10 +343,10 @@ def greedy_search(
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
).unsqueeze(1)
|
||||
# decoder_out is of shape (N, 1, decoder_out_dim)
|
||||
).squeeze(1)
|
||||
# decoder_out is of shape (N, decoder_out_dim)
|
||||
else:
|
||||
decoder_out = torch.cat(
|
||||
decoder_out = torch.stack(
|
||||
[stream.decoder_out for stream in streams],
|
||||
dim=0,
|
||||
)
|
||||
@ -358,13 +355,12 @@ def greedy_search(
|
||||
|
||||
T = encoder_out.size(1)
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
current_encoder_out = encoder_out[:, t]
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
@ -380,9 +376,9 @@ def greedy_search(
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input, need_pad=False
|
||||
).unsqueeze(1)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
|
||||
1
|
||||
)
|
||||
|
||||
for k, s in enumerate(streams):
|
||||
logging.info(
|
||||
@ -392,7 +388,7 @@ def greedy_search(
|
||||
decoder_out_list = decoder_out.unbind(dim=0)
|
||||
|
||||
for i, d in enumerate(decoder_out_list):
|
||||
streams[i].decoder_out = d.unsqueeze(0)
|
||||
streams[i].decoder_out = d
|
||||
|
||||
|
||||
def process_features(
|
||||
@ -424,6 +420,10 @@ def process_features(
|
||||
fill_value=features.size(1),
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Caution: It has a limitation as it assumes that
|
||||
# if one of the stream has an empty state, then all other
|
||||
# streams also have empty states.
|
||||
if streams[0].states is None:
|
||||
states = None
|
||||
else:
|
||||
|
@ -60,6 +60,7 @@ class FeatureExtractionStream(object):
|
||||
|
||||
# For the RNN-T decoder, it contains the decoder output
|
||||
# corresponding to the decoder input self.hyp.ys[-context_size:]
|
||||
# Its shape is (decoder_out_dim,)
|
||||
self.decoder_out: Optional[torch.Tensor] = None
|
||||
|
||||
# After calling `self.input_finished()`, we set this flag to True
|
||||
|
Loading…
x
Reference in New Issue
Block a user