fix bugs, no grad, and num_processed_frames

yaozengwei 2022-06-10 17:17:03 +08:00
parent 507d7c13f4
commit 026fb22076
2 changed files with 40 additions and 35 deletions


@@ -1564,29 +1564,31 @@ class EmformerEncoder(nn.Module):
)
# calculate padding mask to mask out initial zero caches
# chunk_mask = make_pad_mask(output_lengths).to(x.device)
# memory_mask = (
#     (past_lens // self.chunk_length).view(x.size(1), 1)
#     <= torch.arange(self.memory_size, device=x.device).expand(
#         x.size(1), self.memory_size
#     )
# ).flip(1)
# left_context_mask = (
#     past_lens.view(x.size(1), 1)
#     <= torch.arange(self.left_context_length, device=x.device).expand(
#         x.size(1), self.left_context_length
#     )
# ).flip(1)
# right_context_mask = torch.zeros(
#     x.size(1),
#     self.right_context_length,
#     dtype=torch.bool,
#     device=x.device,
# )
# padding_mask = torch.cat(
#     [memory_mask, left_context_mask, right_context_mask, chunk_mask],
#     dim=1,
# )
chunk_mask = make_pad_mask(output_lengths).to(x.device)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    memory_mask = (
        (num_processed_frames // self.chunk_length).view(x.size(1), 1)
        <= torch.arange(self.memory_size, device=x.device).expand(
            x.size(1), self.memory_size
        )
    ).flip(1)
left_context_mask = (
    num_processed_frames.view(x.size(1), 1)
    <= torch.arange(self.left_context_length, device=x.device).expand(
        x.size(1), self.left_context_length
    )
).flip(1)
right_context_mask = torch.zeros(
    x.size(1),
    self.right_context_length,
    dtype=torch.bool,
    device=x.device,
)
padding_mask = torch.cat(
    [memory_mask, right_context_mask, left_context_mask, chunk_mask],
    dim=1,
)
output = utterance
output_attn_caches: List[List[torch.Tensor]] = []
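
A note on the mask logic above: each stream's attention caches are zero-initialized, so until enough chunks have been processed the leading memory and left-context slots hold no real history and must be masked out. Below is a minimal sketch of the memory-mask arithmetic with assumed toy sizes (chunk_length=32, memory_size=4), not the model's real configuration:

import torch

chunk_length, memory_size = 32, 4
# frames already seen by two hypothetical streams: one brand new, one older
num_processed_frames = torch.tensor([0, 96])

num_chunks = num_processed_frames // chunk_length  # chunks fed so far
memory_mask = (
    num_chunks.view(-1, 1)
    <= torch.arange(memory_size).expand(2, memory_size)
).flip(1)
print(memory_mask)
# tensor([[ True,  True,  True,  True],    -> no chunks yet: mask every slot
#         [ True, False, False, False]])   -> 3 chunks seen: only the oldest slot is empty

True marks positions to ignore, the same convention as make_pad_mask; the flip(1) moves the invalid slots to the left, presumably matching a cache layout whose oldest entries come first.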
@@ -1602,7 +1604,7 @@ class EmformerEncoder(nn.Module):
    output,
    right_context,
    memory,
    # padding_mask=padding_mask,
    padding_mask=padding_mask,
    attn_cache=attn_caches[layer_idx],
    conv_cache=conv_caches[layer_idx],
)
@@ -1765,6 +1767,8 @@ class Emformer(EncoderInterface):
x_lens -= 2
assert x.size(0) == x_lens.max().item()

num_processed_frames = num_processed_frames >> 2

output, output_lengths, output_states = self.encoder.infer(
    x, x_lens, num_processed_frames, states
)
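
The num_processed_frames >> 2 line converts the counter from the input feature frame rate to the encoder's frame rate: >> 2 is a shift-based divide by 4, presumably matching a frontend subsampling factor of 4. A toy check of the arithmetic:

# assumed: features at the 10 ms frame rate, frontend subsamples by 4
num_processed_frames = 1280
subsampled = num_processed_frames >> 2  # shift right by 2 == floor-divide by 4
assert subsampled == num_processed_frames // 4 == 320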


@@ -17,6 +17,7 @@
# limitations under the License.
import argparse
import copy
import logging
import math
import warnings
@@ -261,11 +262,14 @@ def decode_one_chunk(
num_processed_frames_list = []
for stream in streams:
    # We should read `stream.num_processed_frames` before calling
    # `stream.get_feature_chunk()`, since that call updates
    # `stream.num_processed_frames`.
    num_processed_frames_list.append(stream.num_processed_frames)
    feature, feature_len = stream.get_feature_chunk()
    feature_list.append(feature)
    feature_len_list.append(feature_len)
    state_list.append(stream.states)
    num_processed_frames_list.append(stream.num_processed_frames)

features = pad_sequence(
    feature_list, batch_first=True, padding_value=LOG_EPSILON
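
The reordering in this hunk fixes a read-after-mutate bug: get_feature_chunk() advances stream.num_processed_frames, so appending the counter after that call records a value that already includes the new chunk. A toy illustration with a hypothetical Stream class (not icefall's actual implementation):

class Stream:
    def __init__(self):
        self.num_processed_frames = 0

    def get_feature_chunk(self):
        # returns the next chunk and, as a side effect, advances the counter
        start = self.num_processed_frames
        self.num_processed_frames += 8
        return list(range(start, start + 8))

s = Stream()
before = s.num_processed_frames   # 0: frames seen *before* this chunk (correct)
chunk = s.get_feature_chunk()
after = s.num_processed_frames    # 8: already counts the new chunk (the old bug)
assert (before, after) == (0, 8)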
@@ -288,7 +292,6 @@ def decode_one_chunk(
    mode="constant",
    value=LOG_EPSILON,
)
# print(features.shape)

# stack states of all streams
states = stack_states(state_list)
@@ -300,12 +303,6 @@ def decode_one_chunk(
)
encoder_out = model.joiner.encoder_proj(encoder_out)

# update cached states of each stream
state_list = unstack_states(states)
assert len(streams) == len(state_list)
for i, s in enumerate(state_list):
    streams[i].states = s

if params.decoding_method == "greedy_search":
    greedy_search(
        model=model,
@@ -325,6 +322,11 @@ def decode_one_chunk(
f"Unsupported decoding method: {params.decoding_method}"
)
# update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
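
For batched streaming, the per-stream caches are stacked into one batched state for the encoder call and unstacked back onto the streams afterwards; this hunk moves that write-back to after the search step. A minimal sketch of the round trip, using simplified stand-ins for stack_states/unstack_states (the real ones handle nested per-layer cache lists):

from typing import List

import torch

def stack_states(state_list: List[torch.Tensor]) -> torch.Tensor:
    # one cache tensor per stream -> one batched tensor, batch dim first
    return torch.stack(state_list, dim=0)

def unstack_states(states: torch.Tensor) -> List[torch.Tensor]:
    # batched tensor -> list of per-stream cache tensors
    return [s.squeeze(0) for s in states.split(1, dim=0)]

caches = [torch.zeros(4, 8), torch.ones(4, 8)]  # two streams' caches
batched = stack_states(caches)                  # shape (2, 4, 8)
per_stream = unstack_states(batched)
assert all(torch.equal(a, b) for a, b in zip(caches, per_stream))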
@@ -399,9 +401,7 @@ def decode_dataset(
        sp.decode(streams[i].decoding_result()).split(),
    )
)
print(decode_results[-1])
del streams[i]
# print("delete", i, len(streams))

if num % log_interval == 0:
    logging.info(f"Cuts processed until now is {num}.")
@@ -470,9 +470,10 @@ def save_results(
for key, val in test_set_wers:
    s += "{}\t{}{}\n".format(key, val, note)
    note = ""

logging.info(s)


@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
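
The "no grad" part of the commit is the @torch.no_grad() decorator now placed on main(): it disables autograd for the whole decoding run, so no graph is kept for backprop and memory use drops. torch.no_grad works both as a context manager and, as here, as a decorator; a small self-contained example:

import torch

@torch.no_grad()
def decode(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # everything inside runs without building an autograd graph
    return model(x)

model = torch.nn.Linear(4, 2)
out = decode(model, torch.randn(3, 4))
assert not out.requires_grad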