mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
No more T < S after frame_reducer (#875)
* No more T < S after frame_reducer * Fix for style check * Adjust the permissions * Add support for inference to frame_reducer * Fix for flake8 check --------- Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
parent
bf5f0342a2
commit
caf23546ed
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py
Executable file → Normal file
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/__init__.py
Executable file → Normal file
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
Normal file → Executable file
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
Normal file → Executable file
@ -1,7 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
#
|
#
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
|
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
|
||||||
# Zengwei Yao)
|
# Zengwei Yao,
|
||||||
|
# Wei Kang)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -18,7 +19,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
ctc_output: torch.Tensor,
|
ctc_output: torch.Tensor,
|
||||||
|
y_lens: Optional[torch.Tensor] = None,
|
||||||
blank_id: int = 0,
|
blank_id: int = 0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
|
|||||||
`x` before padding.
|
`x` before padding.
|
||||||
ctc_output:
|
ctc_output:
|
||||||
The CTC output with shape [N, T, vocab_size].
|
The CTC output with shape [N, T, vocab_size].
|
||||||
|
y_lens:
|
||||||
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
|
`y` before padding.
|
||||||
blank_id:
|
blank_id:
|
||||||
The blank id of ctc_output.
|
The blank id of ctc_output.
|
||||||
Returns:
|
Returns:
|
||||||
@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
|
|||||||
A tensor of shape (batch_size,) containing the number of frames in
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
`out` before padding.
|
`out` before padding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
N, T, C = x.size()
|
N, T, C = x.size()
|
||||||
|
|
||||||
padding_mask = make_pad_mask(x_lens)
|
padding_mask = make_pad_mask(x_lens)
|
||||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||||
|
|
||||||
|
if y_lens is not None:
|
||||||
|
# Limit the maximum number of reduced frames
|
||||||
|
limit_lens = T - y_lens
|
||||||
|
max_limit_len = limit_lens.max().int()
|
||||||
|
fake_limit_indexes = torch.topk(
|
||||||
|
ctc_output[:, :, blank_id], max_limit_len
|
||||||
|
).indices
|
||||||
|
T = (
|
||||||
|
torch.arange(max_limit_len)
|
||||||
|
.expand_as(
|
||||||
|
fake_limit_indexes,
|
||||||
|
)
|
||||||
|
.to(device=x.device)
|
||||||
|
)
|
||||||
|
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
||||||
|
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
|
||||||
|
limit_mask = torch.full_like(
|
||||||
|
non_blank_mask,
|
||||||
|
False,
|
||||||
|
device=x.device,
|
||||||
|
).scatter_(1, limit_indexes, True)
|
||||||
|
|
||||||
|
non_blank_mask = non_blank_mask | ~limit_mask
|
||||||
|
|
||||||
out_lens = non_blank_mask.sum(dim=1)
|
out_lens = non_blank_mask.sum(dim=1)
|
||||||
max_len = out_lens.max()
|
max_len = out_lens.max()
|
||||||
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
|
pad_lens_list = (
|
||||||
|
torch.full_like(
|
||||||
|
out_lens,
|
||||||
|
max_len.item(),
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
- out_lens
|
||||||
|
)
|
||||||
max_pad_len = pad_lens_list.max()
|
max_pad_len = pad_lens_list.max()
|
||||||
|
|
||||||
out = F.pad(x, (0, 0, 0, max_pad_len))
|
out = F.pad(x, (0, 0, 0, max_pad_len))
|
||||||
@ -82,26 +117,30 @@ class FrameReducer(nn.Module):
|
|||||||
|
|
||||||
out = out[total_valid_mask].reshape(N, -1, C)
|
out = out[total_valid_mask].reshape(N, -1, C)
|
||||||
|
|
||||||
return out.to(device=x.device), out_lens.to(device=x.device)
|
return out, out_lens
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import time
|
import time
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
test_times = 10000
|
test_times = 10000
|
||||||
|
device = "cuda:0"
|
||||||
frame_reducer = FrameReducer()
|
frame_reducer = FrameReducer()
|
||||||
|
|
||||||
# non zero case
|
# non zero case
|
||||||
x = torch.ones(15, 498, 384, dtype=torch.float32)
|
x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
|
||||||
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
|
||||||
ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
|
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
|
||||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
ctc_output = torch.log(
|
||||||
|
torch.randn(15, 498, 500, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
avg_time = 0
|
avg_time = 0
|
||||||
for i in range(test_times):
|
for i in range(test_times):
|
||||||
|
torch.cuda.synchronize(device=x.device)
|
||||||
delta_time = time.time()
|
delta_time = time.time()
|
||||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
|
||||||
|
torch.cuda.synchronize(device=x.device)
|
||||||
delta_time = time.time() - delta_time
|
delta_time = time.time() - delta_time
|
||||||
avg_time += delta_time
|
avg_time += delta_time
|
||||||
print(x_fr.shape)
|
print(x_fr.shape)
|
||||||
@ -109,14 +148,17 @@ if __name__ == "__main__":
|
|||||||
print(avg_time / test_times)
|
print(avg_time / test_times)
|
||||||
|
|
||||||
# all zero case
|
# all zero case
|
||||||
x = torch.zeros(15, 498, 384, dtype=torch.float32)
|
x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
|
||||||
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
|
||||||
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
|
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
|
||||||
|
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
avg_time = 0
|
avg_time = 0
|
||||||
for i in range(test_times):
|
for i in range(test_times):
|
||||||
|
torch.cuda.synchronize(device=x.device)
|
||||||
delta_time = time.time()
|
delta_time = time.time()
|
||||||
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
|
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
|
||||||
|
torch.cuda.synchronize(device=x.device)
|
||||||
delta_time = time.time() - delta_time
|
delta_time = time.time() - delta_time
|
||||||
avg_time += delta_time
|
avg_time += delta_time
|
||||||
print(x_fr.shape)
|
print(x_fr.shape)
|
||||||
|
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
Executable file → Normal file
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/lconv.py
Executable file → Normal file
10
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
Executable file → Normal file
10
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py
Executable file → Normal file
@ -131,6 +131,10 @@ class Transducer(nn.Module):
|
|||||||
# compute ctc log-probs
|
# compute ctc log-probs
|
||||||
ctc_output = self.ctc_output(encoder_out)
|
ctc_output = self.ctc_output(encoder_out)
|
||||||
|
|
||||||
|
# y_lens
|
||||||
|
row_splits = y.shape.row_splits(1)
|
||||||
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
# blank skip
|
# blank skip
|
||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
|
|
||||||
@ -146,16 +150,14 @@ class Transducer(nn.Module):
|
|||||||
encoder_out,
|
encoder_out,
|
||||||
x_lens,
|
x_lens,
|
||||||
ctc_output,
|
ctc_output,
|
||||||
|
y_lens,
|
||||||
blank_id,
|
blank_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_out_fr = encoder_out
|
encoder_out_fr = encoder_out
|
||||||
x_lens_fr = x_lens
|
x_lens_fr = x_lens
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# sos_y
|
||||||
row_splits = y.shape.row_splits(1)
|
|
||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
|
||||||
|
|
||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
# sos_y_padded: [B, S + 1], start with SOS.
|
# sos_y_padded: [B, S + 1], start with SOS.
|
||||||
|
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
Normal file → Executable file
0
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py
Normal file → Executable file
@ -1,4 +1,3 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang,
|
# Wei Kang,
|
||||||
# Mingshuang Luo,
|
# Mingshuang Luo,
|
||||||
@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 750
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user