Mirror of https://github.com/k2-fsa/icefall.git
Synced 2025-08-08 09:32:20 +00:00
No more T < S after frame_reducer (#875)
* No more T < S after frame_reducer
* Fix for style check
* Adjust the permissions
* Add support for inference to frame_reducer
* Fix for flake8 check

Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
parent bf5f0342a2
commit caf23546ed
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py (0 lines changed, Executable file → Normal file)

egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py (0 lines changed, Normal file → Executable file)
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/frame_reducer.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
-# Copyright  2022  Xiaomi Corp.        (authors: Yifan Yang,
-#                                                Zengwei Yao)
+# Copyright  2022  Xiaomi Corp.        (authors: Yifan Yang,
+#                                                Zengwei Yao,
+#                                                Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -18,7 +19,7 @@
 # limitations under the License.

 import math
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         ctc_output: torch.Tensor,
+        y_lens: Optional[torch.Tensor] = None,
         blank_id: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
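The new y_lens argument defaults to None, which is what makes the same module usable for inference as well as training (the "Add support for inference to frame_reducer" item in the commit message): at decoding time there are no label lengths, so the limiting branch added below is simply skipped. A sketch of the two call patterns, assuming frame_reducer = FrameReducer() and tensors shaped as in the docstring:

x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)  # training: cap the reduction
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)          # inference: blank mask only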
@@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
             `x` before padding.
           ctc_output:
             The CTC output with shape [N, T, vocab_size].
+          y_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `y` before padding.
           blank_id:
             The blank id of ctc_output.
         Returns:
@@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
             A tensor of shape (batch_size,) containing the number of frames in
             `out` before padding.
         """

         N, T, C = x.size()

         padding_mask = make_pad_mask(x_lens)
         non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)

+        if y_lens is not None:
+            # Limit the maximum number of reduced frames
+            limit_lens = T - y_lens
+            max_limit_len = limit_lens.max().int()
+            fake_limit_indexes = torch.topk(
+                ctc_output[:, :, blank_id], max_limit_len
+            ).indices
+            T = (
+                torch.arange(max_limit_len)
+                .expand_as(
+                    fake_limit_indexes,
+                )
+                .to(device=x.device)
+            )
+            T = torch.remainder(T, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            limit_mask = torch.full_like(
+                non_blank_mask,
+                False,
+                device=x.device,
+            ).scatter_(1, limit_indexes, True)
+
+            non_blank_mask = non_blank_mask | ~limit_mask
+
         out_lens = non_blank_mask.sum(dim=1)
         max_len = out_lens.max()
-        pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
+        pad_lens_list = (
+            torch.full_like(
+                out_lens,
+                max_len.item(),
+                device=x.device,
+            )
+            - out_lens
+        )
         max_pad_len = pad_lens_list.max()

         out = F.pad(x, (0, 0, 0, max_pad_len))
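What the added branch guarantees: the blank-probability threshold alone can drop so many frames that the reduced frame count T falls below the label length S, which the transducer loss cannot handle, so the branch caps the number of dropped frames at T - y_lens per utterance; that cap is exactly what makes "no more T < S" hold. The trick is that torch.remainder wraps the per-row position index, so a row with a smaller drop budget gathers duplicate indices and therefore marks fewer distinct frames. A self-contained sketch of the same mask construction with made-up shapes (N=2, T=6; all names are local to this example):

import math

import torch

N, T = 2, 6
blank_logprob = torch.log(torch.rand(N, T))  # stand-in for ctc_output[:, :, blank_id]
y_lens = torch.tensor([4, 2])                # label lengths per utterance

non_blank_mask = blank_logprob < math.log(0.9)  # frames we would like to keep

limit_lens = T - y_lens                      # per-row budget of droppable frames
max_limit_len = int(limit_lens.max())
# Indices of the most blank-like frames; more than some rows may drop.
fake_limit_indexes = torch.topk(blank_logprob, max_limit_len).indices
# remainder wraps positions, so rows with a small budget reuse earlier
# indices and scatter marks at most limit_lens distinct frames per row.
pos = torch.remainder(
    torch.arange(max_limit_len).expand_as(fake_limit_indexes),
    limit_lens.unsqueeze(1),
)
limit_indexes = torch.gather(fake_limit_indexes, 1, pos)
limit_mask = torch.zeros_like(non_blank_mask).scatter_(1, limit_indexes, True)

# Frames outside the drop budget are force-kept.
non_blank_mask = non_blank_mask | ~limit_mask
assert (non_blank_mask.sum(dim=1) >= y_lens).all()  # T >= S per utterance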
@@ -82,26 +117,30 @@ class FrameReducer(nn.Module):

         out = out[total_valid_mask].reshape(N, -1, C)

-        return out.to(device=x.device), out_lens.to(device=x.device)
+        return out, out_lens


 if __name__ == "__main__":
     import time
     from torch.nn.utils.rnn import pad_sequence

     test_times = 10000
     device = "cuda:0"
     frame_reducer = FrameReducer()

     # non zero case
-    x = torch.ones(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
-    x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.log(
+        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
+    )

     avg_time = 0
     for i in range(test_times):
         torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
         torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
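Two independent fixes ride along in this benchmark: the return value no longer round-trips through .to(device=...), since out and out_lens are already on x's device, and the test tensors are now created on the GPU up front, so the loop times the kernels rather than host-to-device copies. The synchronize-before-and-after pattern used here generalizes; a hypothetical helper in the same spirit (cuda_timeit is not part of the PR):

import time

import torch


def cuda_timeit(fn, iters: int = 10000) -> float:
    """Average wall-clock seconds per call of fn(). CUDA launches are
    asynchronous, so synchronize before reading the clock on both sides."""
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters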
@@ -109,14 +148,17 @@ if __name__ == "__main__":
     print(avg_time / test_times)

     # all zero case
-    x = torch.zeros(15, 498, 384, dtype=torch.float32)
-    x_lens = torch.tensor([498] * 15, dtype=torch.int64)
-    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
+    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
+    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
+    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
+    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

     avg_time = 0
     for i in range(test_times):
         torch.cuda.synchronize(device=x.device)
         delta_time = time.time()
-        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
+        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
         torch.cuda.synchronize(device=x.device)
         delta_time = time.time() - delta_time
         avg_time += delta_time
     print(x_fr.shape)
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py (0 lines changed, Executable file → Normal file)

egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py (10 lines changed, Executable file → Normal file)
@@ -131,6 +131,10 @@ class Transducer(nn.Module):
         # compute ctc log-probs
         ctc_output = self.ctc_output(encoder_out)

+        # y_lens
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         # blank skip
         blank_id = self.decoder.blank_id
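The added lines recover per-utterance label lengths from the k2 ragged tensor y: row_splits(1) is the boundary array (an exclusive prefix sum of the per-utterance label counts), so adjacent differences are the lengths. A sketch with plain torch standing in for k2, using made-up counts:

import torch

# row_splits for three utterances with 3, 4 and 5 labels respectively
row_splits = torch.tensor([0, 3, 7, 12])
y_lens = row_splits[1:] - row_splits[:-1]
print(y_lens)  # tensor([3, 4, 5])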
@@ -146,16 +150,14 @@ class Transducer(nn.Module):
                 encoder_out,
                 x_lens,
                 ctc_output,
+                y_lens,
                 blank_id,
             )
         else:
             encoder_out_fr = encoder_out
             x_lens_fr = x_lens

         # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
-
         # sos_y
         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
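Net effect in Transducer.forward: y_lens is now computed once, before frame reduction, and passed to the frame reducer so the drop cap can be applied, rather than being computed later for the decoder only. A condensed view of the resulting call (names as in the diff; the assert is illustrative, not in the source):

encoder_out_fr, x_lens_fr = self.frame_reducer(
    encoder_out,  # (N, T, C) encoder frames
    x_lens,       # (N,) valid frame counts
    ctc_output,   # (N, T, vocab_size) CTC log-probs
    y_lens,       # (N,) label lengths: at least this many frames survive
    blank_id,
)
assert (x_lens_fr >= y_lens).all()  # the invariant the commit title promises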
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py (0 lines changed, Normal file → Executable file)
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py

@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
 #                                            Wei Kang,
 #                                            Mingshuang Luo,
@@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7_ctc_bs/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 """