No more T < S after frame_reducer (#875)

* No more T < S after frame_reducer

* Fix for style check

* Adjust the permissions

* Add support for inference to frame_reducer

* Fix for flake8 check

---------

Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
Yifan Yang 2023-02-06 12:17:45 +08:00 committed by GitHub
parent bf5f0342a2
commit caf23546ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 65 additions and 22 deletions

View File

View File

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao)
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -18,7 +19,7 @@
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
ctc_output: torch.Tensor,
y_lens: Optional[torch.Tensor] = None,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
`x` before padding.
ctc_output:
The CTC output with shape [N, T, vocab_size].
y_lens:
A tensor of shape (batch_size,) containing the number of frames in
`y` before padding.
blank_id:
The blank id of ctc_output.
Returns:
@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
A tensor of shape (batch_size,) containing the number of frames in
`out` before padding.
"""
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
if y_lens is not None:
# Limit the maximum number of reduced frames
limit_lens = T - y_lens
max_limit_len = limit_lens.max().int()
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
).indices
T = (
torch.arange(max_limit_len)
.expand_as(
fake_limit_indexes,
)
.to(device=x.device)
)
T = torch.remainder(T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
limit_mask = torch.full_like(
non_blank_mask,
False,
device=x.device,
).scatter_(1, limit_indexes, True)
non_blank_mask = non_blank_mask | ~limit_mask
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
pad_lens_list = (
torch.full_like(
out_lens,
max_len.item(),
device=x.device,
)
- out_lens
)
max_pad_len = pad_lens_list.max()
out = F.pad(x, (0, 0, 0, max_pad_len))
@ -82,26 +117,30 @@ class FrameReducer(nn.Module):
out = out[total_valid_mask].reshape(N, -1, C)
return out.to(device=x.device), out_lens.to(device=x.device)
return out, out_lens
if __name__ == "__main__":
import time
from torch.nn.utils.rnn import pad_sequence
test_times = 10000
device = "cuda:0"
frame_reducer = FrameReducer()
# non zero case
x = torch.ones(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.log(
torch.randn(15, 498, 500, dtype=torch.float32, device=device),
)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
@ -109,14 +148,17 @@ if __name__ == "__main__":
print(avg_time / test_times)
# all zero case
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)

View File

View File

@ -131,6 +131,10 @@ class Transducer(nn.Module):
# compute ctc log-probs
ctc_output = self.ctc_output(encoder_out)
# y_lens
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
# blank skip
blank_id = self.decoder.blank_id
@ -146,16 +150,14 @@ class Transducer(nn.Module):
encoder_out,
x_lens,
ctc_output,
y_lens,
blank_id,
)
else:
encoder_out_fr = encoder_out
x_lens_fr = x_lens
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
# sos_y
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.

View File

@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
--full-libri 1 \
--max-duration 550
--max-duration 750
"""