mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00

Use torch.stack() to replace torch.cat()

This commit is contained in:
parent 4cef2728cd
commit 747339a6c1
@@ -32,13 +32,16 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, s_range, C).
+            Output from the encoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
           decoder_out:
-            Output from the decoder. Its shape is (N, T, s_range, C).
+            Output from the decoder. Its shape is (N, T, s_range, C) for
+            training and (N, C) for streaming decoding.
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
 
         logit = encoder_out + decoder_out
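For context, a minimal sketch of what the relaxed contract amounts to: the joiner now accepts either the 4-D training layout or the 2-D streaming layout, since elementwise addition works for both. Only the asserts and the additive join are confirmed by the diff; the tanh-plus-linear tail is an illustrative assumption.

import torch
import torch.nn as nn


class Joiner(nn.Module):
    # Hedged sketch; the real icefall Joiner may differ in its output layers.
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # 4-D (N, T, s_range, C) in training, 2-D (N, C) in streaming decoding
        assert encoder_out.ndim == decoder_out.ndim
        assert encoder_out.ndim in (2, 4)
        assert encoder_out.shape == decoder_out.shape

        logit = encoder_out + decoder_out
        # Assumed tail: nn.Linear broadcasts over all leading dims, so the
        # same code serves both the 2-D and the 4-D case.
        return self.output_linear(torch.tanh(logit))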
@@ -51,8 +51,9 @@ def unstack_states(
     for li, layer in enumerate(states):
         for s in layer:
             s_list = s.unbind(dim=1)
+            # We will use stack(dim=1) later in stack_states()
             for bi, b in enumerate(ans):
-                b[li].append(s_list[bi].unsqueeze(dim=1))
+                b[li].append(s_list[bi])
     return ans
 
 
@@ -75,15 +76,23 @@ def stack_states(
     See the input argument of :func:`unstack_states` for the meaning
     of the returned tensor.
     """
+    batch_size = len(state_list)
     ans = []
     for layer in state_list[0]:
         # layer is a list of tensors
-        ans.append([s for s in layer])
+        if batch_size > 1:
+            ans.append([[s] for s in layer])
+            # Note: We will stack ans[layer][s][] later to get ans[layer][s]
+        else:
+            ans.append([s.unsqueeze(1) for s in layer])
 
-    for states in state_list[1:]:
+    for b, states in enumerate(state_list[1:], 1):
         for li, layer in enumerate(states):
             for si, s in enumerate(layer):
-                ans[li][si] = torch.cat([ans[li][si], s], dim=1)
+                ans[li][si].append(s)
+                if b == batch_size - 1:
+                    ans[li][si] = torch.stack(ans[li][si], dim=1)
+                    # We will use unbind(dim=1) later in unstack_states()
     return ans
 
 
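The two comments added above reference each other: stack_states() now ends with torch.stack(..., dim=1) precisely because unstack_states() begins with unbind(dim=1), and the two operations are exact inverses. A small self-contained check (shapes made up for illustration):

import torch

# unbind(dim=1) drops the batch dimension; stack(dim=1) re-inserts it, so the
# two functions round-trip without per-item unsqueeze() followed by cat().
s = torch.randn(2, 4, 512)             # e.g. (num_directions, batch=4, hidden)
per_item = s.unbind(dim=1)             # 4 tensors, each of shape (2, 512)
restored = torch.stack(per_item, dim=1)
assert torch.equal(restored, s)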
@@ -190,11 +190,7 @@ class StreamingAudioSamples(object):
         """
         ans = []
 
-        # Note: Either branch is fine. The purpose is to simulate streaming
-        if False:
-            num = torch.randint(2000, 5000, (len(self.samples),)).tolist()
-        else:
-            num = [1024] * len(self.samples)
+        num = [1024] * len(self.samples)
 
         for i in range(len(self.samples)):
             start = self.cur_indexes[i]
@@ -293,16 +289,17 @@ class StreamList(object):
                 # has a shape (1, feature_dim)
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = stream.feature_frames[segment_length:]
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 feature_list.append(features)
                 stream_list.append(stream)
             elif stream.done and len(stream.feature_frames) > 0:
                 chunk = stream.feature_frames[:chunk_length]
                 stream.feature_frames = []
-                features = torch.cat(chunk, dim=0).unsqueeze(0)
+                features = torch.cat(chunk, dim=0)
                 features = torch.nn.functional.pad(
                     features,
-                    (0, 0, 0, chunk_length - features.size(1)),
+                    (0, 0, 0, chunk_length - features.size(0)),
+                    mode="constant",
                     value=LOG_EPSILON,
                 )
                 feature_list.append(features)
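With the unsqueeze(0) gone, each chunk is a 2-D (num_frames, feature_dim) tensor, so the amount of padding must be read from size(0) rather than size(1). A quick illustration of the pad spec (the LOG_EPSILON value below is a stand-in for icefall's actual constant):

import torch
import torch.nn.functional as F

LOG_EPSILON = -23.0  # stand-in; icefall defines its own log-epsilon constant

chunk_length = 8
features = torch.randn(5, 80)  # (num_frames=5, feature_dim=80), no batch dim
# The pad spec is read right to left, one (before, after) pair per trailing
# dimension: (0, 0) leaves feature_dim untouched, and
# (0, chunk_length - num_frames) appends frames along dim 0.
features = F.pad(
    features,
    (0, 0, 0, chunk_length - features.size(0)),
    mode="constant",
    value=LOG_EPSILON,
)
assert features.shape == (8, 80)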
@@ -311,7 +308,7 @@ class StreamList(object):
         if len(feature_list) == 0:
             return None, None
 
-        features = torch.cat(feature_list, dim=0)
+        features = torch.stack(feature_list, dim=0)
         return features, stream_list
 
 
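This line is the heart of the commit: once each element of feature_list is 2-D, torch.stack() creates the batch dimension directly, instead of every element carrying a dummy leading dimension for torch.cat() to concatenate over. The equivalence, on made-up shapes:

import torch

a, b = torch.randn(8, 80), torch.randn(8, 80)
# New style: stack 2-D tensors, inserting the batch dimension.
batched = torch.stack([a, b], dim=0)                              # (2, 8, 80)
# Old style: give each tensor a dummy batch dim, then concatenate over it.
batched_old = torch.cat([a.unsqueeze(0), b.unsqueeze(0)], dim=0)  # (2, 8, 80)
assert torch.equal(batched, batched_old)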
@@ -346,10 +343,10 @@ def greedy_search(
        decoder_out = model.decoder(
            decoder_input,
            need_pad=False,
-        ).unsqueeze(1)
-        # decoder_out is of shape (N, 1, decoder_out_dim)
+        ).squeeze(1)
+        # decoder_out is of shape (N, decoder_out_dim)
    else:
-        decoder_out = torch.cat(
+        decoder_out = torch.stack(
            [stream.decoder_out for stream in streams],
            dim=0,
        )
@@ -358,13 +355,12 @@ def greedy_search(
 
     T = encoder_out.size(1)
     for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
-        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        current_encoder_out = encoder_out[:, t]
+        # current_encoder_out's shape: (batch_size, encoder_out_dim)
 
         logits = model.joiner(current_encoder_out, decoder_out)
-        # logits'shape (batch_size, 1, 1, vocab_size)
+        # logits'shape (batch_size, vocab_size)
 
-        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
         assert logits.ndim == 2, logits.shape
         y = logits.argmax(dim=1).tolist()
         emitted = False
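The old code carved out a 4-D (N, 1, 1, C) frame to satisfy the joiner and then squeezed the logits back down; indexing with encoder_out[:, t] yields the 2-D (N, C) frame directly. The two views hold the same data:

import torch

N, T, C = 2, 10, 512
encoder_out = torch.randn(N, T, C)
t = 3
old = encoder_out[:, t : t + 1, :].unsqueeze(2)  # (N, 1, 1, C), old convention
new = encoder_out[:, t]                          # (N, C), new convention
assert torch.equal(old.squeeze(2).squeeze(1), new)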
@@ -380,9 +376,9 @@ def greedy_search(
                device=device,
                dtype=torch.int64,
            )
-            decoder_out = model.decoder(
-                decoder_input, need_pad=False
-            ).unsqueeze(1)
+            decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(
+                1
+            )
 
    for k, s in enumerate(streams):
        logging.info(
@@ -392,7 +388,7 @@ def greedy_search(
    decoder_out_list = decoder_out.unbind(dim=0)
 
    for i, d in enumerate(decoder_out_list):
-        streams[i].decoder_out = d.unsqueeze(0)
+        streams[i].decoder_out = d
 
 
def process_features(
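Together with the torch.stack() in the earlier greedy_search hunk, this makes each stream's decoder_out a bare (decoder_out_dim,) tensor: stack(dim=0) forms the batch before the joiner and unbind(dim=0) splits it back, with no per-stream unsqueeze(0) bookkeeping. Schematically:

import torch

decoder_out_dim = 512
# One 1-D tensor per active stream, shape (decoder_out_dim,).
per_stream = [torch.randn(decoder_out_dim) for _ in range(3)]
decoder_out = torch.stack(per_stream, dim=0)   # (3, decoder_out_dim)
for before, after in zip(per_stream, decoder_out.unbind(dim=0)):
    assert torch.equal(before, after)          # exact round trip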
@@ -424,6 +420,10 @@ def process_features(
        fill_value=features.size(1),
        device=device,
    )
+
+    # Caution: It has a limitation as it assumes that
+    # if one of the stream has an empty state, then all other
+    # streams also have empty states.
    if streams[0].states is None:
        states = None
    else:
@@ -60,6 +60,7 @@ class FeatureExtractionStream(object):
 
        # For the RNN-T decoder, it contains the decoder output
        # corresponding to the decoder input self.hyp.ys[-context_size:]
+        # Its shape is (decoder_out_dim,)
        self.decoder_out: Optional[torch.Tensor] = None
 
        # After calling `self.input_finished()`, we set this flag to True