Mirror of https://github.com/k2-fsa/icefall.git
fix bugs, no grad, and num_processed_frames
This commit is contained in:
parent 507d7c13f4
commit 026fb22076
@@ -1564,29 +1564,31 @@ class EmformerEncoder(nn.Module):
         )

         # calculate padding mask to mask out initial zero caches
-        # chunk_mask = make_pad_mask(output_lengths).to(x.device)
-        # memory_mask = (
-        #     (past_lens // self.chunk_length).view(x.size(1), 1)
-        #     <= torch.arange(self.memory_size, device=x.device).expand(
-        #         x.size(1), self.memory_size
-        #     )
-        # ).flip(1)
-        # left_context_mask = (
-        #     past_lens.view(x.size(1), 1)
-        #     <= torch.arange(self.left_context_length, device=x.device).expand(
-        #         x.size(1), self.left_context_length
-        #     )
-        # ).flip(1)
-        # right_context_mask = torch.zeros(
-        #     x.size(1),
-        #     self.right_context_length,
-        #     dtype=torch.bool,
-        #     device=x.device,
-        # )
-        # padding_mask = torch.cat(
-        #     [memory_mask, left_context_mask, right_context_mask, chunk_mask],
-        #     dim=1,
-        # )
+        chunk_mask = make_pad_mask(output_lengths).to(x.device)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            memory_mask = (
+                (num_processed_frames // self.chunk_length).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+        left_context_mask = (
+            num_processed_frames.view(x.size(1), 1)
+            <= torch.arange(self.left_context_length, device=x.device).expand(
+                x.size(1), self.left_context_length
+            )
+        ).flip(1)
+        right_context_mask = torch.zeros(
+            x.size(1),
+            self.right_context_length,
+            dtype=torch.bool,
+            device=x.device,
+        )
+        padding_mask = torch.cat(
+            [memory_mask, right_context_mask, left_context_mask, chunk_mask],
+            dim=1,
+        )

         output = utterance
         output_attn_caches: List[List[torch.Tensor]] = []
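For readers following the mask logic: the new code builds one boolean padding mask per stream by concatenating masks for the memory bank, the right context, the left context, and the current chunk, with True marking positions that are still zero-initialized caches or padding. Below is a minimal runnable sketch of the same idea; the sizes are illustrative, and make_pad_mask is assumed to return True at padded positions. The warnings guard in the diff presumably silences the tensor floor-division deprecation warning that older PyTorch versions emit for //.

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True marks padding positions past each sequence's valid length.
    max_len = int(lengths.max())
    return torch.arange(max_len).expand(lengths.size(0), max_len) >= lengths.unsqueeze(1)

# Illustrative sizes, not taken from the diff.
batch, memory_size, left_context_length = 2, 4, 8
right_context_length, chunk_length = 2, 8
num_processed_frames = torch.tensor([0, 32])  # frames already consumed per stream
output_lengths = torch.tensor([8, 6])         # valid frames in the current chunk

# A memory slot or left-context frame is masked until enough frames have been
# processed to fill it; flip(1) puts the most recent cache entries last.
memory_mask = (
    (num_processed_frames // chunk_length).view(batch, 1)
    <= torch.arange(memory_size).expand(batch, memory_size)
).flip(1)
left_context_mask = (
    num_processed_frames.view(batch, 1)
    <= torch.arange(left_context_length).expand(batch, left_context_length)
).flip(1)
right_context_mask = torch.zeros(batch, right_context_length, dtype=torch.bool)
chunk_mask = make_pad_mask(output_lengths)

padding_mask = torch.cat(
    [memory_mask, right_context_mask, left_context_mask, chunk_mask], dim=1
)
print(padding_mask.shape)  # torch.Size([2, 22]) = 4 + 2 + 8 + 8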
@@ -1602,7 +1604,7 @@ class EmformerEncoder(nn.Module):
                 output,
                 right_context,
                 memory,
-                # padding_mask=padding_mask,
+                padding_mask=padding_mask,
                 attn_cache=attn_caches[layer_idx],
                 conv_cache=conv_caches[layer_idx],
             )
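With the placeholder comment replaced by a real keyword argument, every layer now receives padding_mask. Inside attention, such a mask is typically applied by setting masked positions to -inf before the softmax so they get zero weight. A generic sketch of that mechanism, not the Emformer layer's actual internals:

import torch
import torch.nn.functional as F

def masked_attention(q, k, v, padding_mask):
    # q, k, v: (batch, seq, dim); padding_mask: (batch, seq), True = ignore key.
    scores = q @ k.transpose(1, 2) / k.size(-1) ** 0.5
    scores = scores.masked_fill(padding_mask.unsqueeze(1), float("-inf"))
    return F.softmax(scores, dim=-1) @ v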
@@ -1765,6 +1767,8 @@ class Emformer(EncoderInterface):
         x_lens -= 2
         assert x.size(0) == x_lens.max().item()

+        num_processed_frames = num_processed_frames >> 2
+
         output, output_lengths, output_states = self.encoder.infer(
             x, x_lens, num_processed_frames, states
         )
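The added num_processed_frames = num_processed_frames >> 2 converts the raw frame counter into the encoder's subsampled frame rate: a right shift by two is integer division by four, which matches a 4x subsampling frontend (an inference from the shift amount, not stated in the diff). A quick check:

import torch

num_processed_frames = torch.tensor([0, 40, 100])
# A right shift by 2 equals floor division by 4 for non-negative counters.
assert torch.equal(num_processed_frames >> 2, num_processed_frames // 4)
print(num_processed_frames >> 2)  # tensor([ 0, 10, 25])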
@@ -17,6 +17,7 @@
 # limitations under the License.

 import argparse
 import copy
 import logging
 import math
+import warnings
@@ -261,11 +262,14 @@ def decode_one_chunk(
     num_processed_frames_list = []

     for stream in streams:
+        # We should first get `stream.num_processed_frames`
+        # before calling `stream.get_feature_chunk()`
+        # since `stream.num_processed_frames` would be updated
+        num_processed_frames_list.append(stream.num_processed_frames)
         feature, feature_len = stream.get_feature_chunk()
         feature_list.append(feature)
         feature_len_list.append(feature_len)
         state_list.append(stream.states)
-        num_processed_frames_list.append(stream.num_processed_frames)

     features = pad_sequence(
         feature_list, batch_first=True, padding_value=LOG_EPSILON
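The reordering here fixes an off-by-one-chunk bug: stream.num_processed_frames must be read before get_feature_chunk(), because fetching a chunk advances the counter, and the encoder needs the number of frames seen before the current chunk. A toy reproduction with a hypothetical stream class (not the repo's implementation):

import torch

class ToyStream:
    # Hypothetical stand-in for a decoding stream.
    def __init__(self, features: torch.Tensor, chunk_size: int):
        self.features = features
        self.chunk_size = chunk_size
        self.num_processed_frames = 0

    def get_feature_chunk(self):
        start = self.num_processed_frames
        chunk = self.features[start : start + self.chunk_size]
        self.num_processed_frames += chunk.size(0)  # the counter advances here
        return chunk, chunk.size(0)

stream = ToyStream(torch.zeros(100, 80), chunk_size=32)
before = stream.num_processed_frames  # 0: frames seen before this chunk
_, n = stream.get_feature_chunk()
after = stream.num_processed_frames   # 32: already includes the new chunk
print(before, after)  # the encoder must be fed `before`, not `after`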
@@ -288,7 +292,6 @@ def decode_one_chunk(
         mode="constant",
         value=LOG_EPSILON,
     )
-    # print(features.shape)
     # stack states of all streams
     states = stack_states(state_list)

@@ -300,12 +303,6 @@ def decode_one_chunk(
     )
     encoder_out = model.joiner.encoder_proj(encoder_out)

-    # update cached states of each stream
-    state_list = unstack_states(states)
-    assert len(streams) == len(state_list)
-    for i, s in enumerate(state_list):
-        streams[i].states = s
-
     if params.decoding_method == "greedy_search":
         greedy_search(
             model=model,
@@ -325,6 +322,11 @@ def decode_one_chunk(
             f"Unsupported decoding method: {params.decoding_method}"
         )

+    # update cached states of each stream
+    state_list = unstack_states(states)
+    for i, s in enumerate(state_list):
+        streams[i].states = s
+
     finished_streams = [i for i, stream in enumerate(streams) if stream.done]
     return finished_streams

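Moving this block after the search (it previously ran right after the encoder, see the hunk above) means the cached states are written back once decoding of the chunk has finished. The stack/unstack pair batches per-stream caches for the encoder and splits them back afterwards; a sketch of that round-trip contract, using a hypothetical single-tensor state instead of the repo's nested per-layer lists:

import torch
from typing import List

def stack_states(state_list: List[torch.Tensor]) -> torch.Tensor:
    return torch.stack(state_list, dim=0)  # batch the per-stream caches

def unstack_states(states: torch.Tensor) -> List[torch.Tensor]:
    return list(states.unbind(dim=0))      # one state per stream again

per_stream = [torch.zeros(2, 3), torch.ones(2, 3)]
restored = unstack_states(stack_states(per_stream))
assert all(torch.equal(a, b) for a, b in zip(per_stream, restored))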
@@ -399,9 +401,7 @@ def decode_dataset(
                     sp.decode(streams[i].decoding_result()).split(),
                 )
             )
-            print(decode_results[-1])
             del streams[i]
-            # print("delete", i, len(streams))

         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")
@@ -470,9 +470,10 @@ def save_results(
     for key, val in test_set_wers:
         s += "{}\t{}{}\n".format(key, val, note)
         note = ""
     logging.info(s)


+@torch.no_grad()
 def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
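The added @torch.no_grad() is the "no grad" of the commit message: it disables autograd tracking for the whole decoding entry point, which saves memory and time during inference. A minimal illustration of the decorator's effect:

import torch

@torch.no_grad()
def decode(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    y = model(x)
    assert not y.requires_grad  # no autograd graph is recorded here
    return y

print(decode(torch.nn.Linear(4, 2), torch.randn(3, 4)))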