No more T < S after frame_reducer (#875)

* No more T < S after frame_reducer

* Fix for style check

* Adjust the permissions

* Add support for inference to frame_reducer

* Fix for flake8 check

---------

Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
Yifan Yang 2023-02-06 12:17:45 +08:00 committed by GitHub
parent bf5f0342a2
commit caf23546ed
7 changed files with 65 additions and 22 deletions
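
What the patch enforces: FrameReducer drops encoder frames whose CTC blank log-probability exceeds log(0.9), and previously nothing prevented it from dropping so many frames that the reduced length T fell below the label length S, which the pruned transducer loss cannot handle. Passing y_lens into FrameReducer caps the number of dropped frames at T - y_lens per utterance. A minimal sketch of the invariant, with assumed toy values:

import torch

# Assumed toy values, not taken from the commit.
x_lens = torch.tensor([498, 498])  # encoder frames per utterance (T)
y_lens = torch.tensor([150, 200])  # label lengths per utterance (S)

# At most T - S frames may be dropped per utterance, so the reduced
# length never falls below S and T >= S still holds afterwards.
max_droppable = x_lens - y_lens
reduced_lens_lower_bound = x_lens - max_droppable
assert torch.equal(reduced_lens_lower_bound, y_lens)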


@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
 # Copyright  2022  Xiaomi Corp.        (authors: Yifan Yang,
-#                                                Zengwei Yao)
+#                                                Zengwei Yao,
+#                                                Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@ -18,7 +19,7 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Tuple, Union from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         ctc_output: torch.Tensor,
+        y_lens: Optional[torch.Tensor] = None,
         blank_id: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
             `x` before padding.
           ctc_output:
             The CTC output with shape [N, T, vocab_size].
+          y_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `y` before padding.
           blank_id:
             The blank id of ctc_output.
         Returns:
@@ -64,15 +69,45 @@
           A tensor of shape (batch_size,) containing the number of frames in
           `out` before padding.
         """

         N, T, C = x.size()

         padding_mask = make_pad_mask(x_lens)
         non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)

+        if y_lens is not None:
+            # Limit the maximum number of reduced frames
+            limit_lens = T - y_lens
+            max_limit_len = limit_lens.max().int()
+            fake_limit_indexes = torch.topk(
+                ctc_output[:, :, blank_id], max_limit_len
+            ).indices
+            T = (
+                torch.arange(max_limit_len)
+                .expand_as(
+                    fake_limit_indexes,
+                )
+                .to(device=x.device)
+            )
+            T = torch.remainder(T, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            limit_mask = torch.full_like(
+                non_blank_mask,
+                False,
+                device=x.device,
+            ).scatter_(1, limit_indexes, True)
+            non_blank_mask = non_blank_mask | ~limit_mask
+
         out_lens = non_blank_mask.sum(dim=1)
         max_len = out_lens.max()
-        pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
+        pad_lens_list = (
+            torch.full_like(
+                out_lens,
+                max_len.item(),
+                device=x.device,
+            )
+            - out_lens
+        )
         max_pad_len = pad_lens_list.max()

         out = F.pad(x, (0, 0, 0, max_pad_len))
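
The capping logic above is dense, so here is a standalone sketch of the same trick on toy shapes (all values assumed): torch.topk picks, per utterance, the indices of the frames most confidently blank; the remainder over arange makes rows with a smaller limit re-select their leading candidates rather than claim extra ones; scatter_ then marks at most limit_lens[i] droppable positions per row, and OR-ing the complement into the keep-mask forces every other frame to survive.

import torch

N, T = 2, 6
y_lens = torch.tensor([4, 2])       # label lengths S per utterance
limit_lens = T - y_lens             # max droppable frames: [2, 4]
blank_logprob = torch.randn(N, T)   # stands in for ctc_output[:, :, blank_id]

max_limit_len = int(limit_lens.max())
# Per row: indices of the max_limit_len most blank-like frames.
cand = torch.topk(blank_logprob, max_limit_len).indices
# Column index modulo each row's own limit; rows with a smaller limit
# wrap around and repeat their leading candidates.
col = torch.arange(max_limit_len).expand_as(cand)
col = torch.remainder(col, limit_lens.unsqueeze(1))
droppable = torch.gather(cand, 1, col)
limit_mask = torch.zeros(N, T, dtype=torch.bool).scatter_(1, droppable, True)
# Frames outside the droppable set are forced to be kept, so each row
# keeps at least y_lens[i] frames.
keep = ~limit_mask
assert (keep.sum(dim=1) >= y_lens).all()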
@@ -82,26 +117,30 @@ class FrameReducer(nn.Module):
         out = out[total_valid_mask].reshape(N, -1, C)

-        return out.to(device=x.device), out_lens.to(device=x.device)
+        return out, out_lens


 if __name__ == "__main__":
     import time

+    from torch.nn.utils.rnn import pad_sequence
+
     test_times = 10000
+    device = "cuda:0"
     frame_reducer = FrameReducer()

     # non zero case
-    x = torch.ones(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
-    x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.log(
+        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
+    )

     avg_time = 0
     for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
@@ -109,14 +148,17 @@ if __name__ == "__main__":
     print(avg_time / test_times)

     # all zero case
-    x = torch.zeros(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
+    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

     avg_time = 0
     for i in range(test_times):
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
+        torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
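
The synchronize calls added around the timed region matter because CUDA kernels launch asynchronously; without them, time.time() would mostly measure kernel-launch overhead rather than execution time. A small helper in the same spirit (a sketch; timed is not part of the commit):

import time

import torch

def timed(fn, *args, repeat: int = 100) -> float:
    """Average wall-clock seconds per call, synchronizing around each call."""
    total = 0.0
    for _ in range(repeat):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.time()
        fn(*args)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        total += time.time() - t0
    return total / repeat

# e.g. timed(frame_reducer, x, x_lens, ctc_output, y_lens)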

@@ -131,6 +131,10 @@ class Transducer(nn.Module):
         # compute ctc log-probs
         ctc_output = self.ctc_output(encoder_out)

+        # y_lens
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         # blank skip
         blank_id = self.decoder.blank_id
@@ -146,16 +150,14 @@
                 encoder_out,
                 x_lens,
                 ctc_output,
+                y_lens,
                 blank_id,
             )
         else:
             encoder_out_fr = encoder_out
             x_lens_fr = x_lens

-        # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
-
+        # sos_y
         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
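
The y_lens computation moves above the frame reducer because the reducer now consumes it; it is no longer tied to preparing the decoder input. Here y is a k2 ragged tensor whose row_splits(1) holds the cumulative label-sequence boundaries, so adjacent differences give per-utterance label counts. The same arithmetic with plain tensors (values assumed):

import torch

# row_splits for 3 utterances whose labels occupy slices 0:3, 3:5, 5:9.
row_splits = torch.tensor([0, 3, 5, 9])
y_lens = row_splits[1:] - row_splits[:-1]
assert y_lens.tolist() == [3, 2, 4]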

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang,
 #                                                       Wei Kang,
 #                                                       Mingshuang Luo,)
@@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 """