mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Modify make_pad_mask to support TensorRT (#943)
* Modify make_pad_mask to support TensorRT * Fix for test
This commit is contained in:
parent
9ddd811925
commit
cad6735e07
@ -432,11 +432,11 @@ def test_layernorm_lstm_forward(device="cpu"):
|
|||||||
|
|
||||||
|
|
||||||
def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
||||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
input_size = 40 # torch.randint(low=2, high=100, size=(1,)).item()
|
||||||
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
|
hidden_size = 40 # torch.randint(low=10, high=100, size=(1,)).item()
|
||||||
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
|
proj_size = 20 # torch.randint(low=2, high=hidden_size, size=(1,)).item()
|
||||||
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
|
num_layers = 12 # torch.randint(low=2, high=100, size=(1,)).item()
|
||||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
bias = True # torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||||
|
|
||||||
self_lstm = LayerNormLSTM(
|
self_lstm = LayerNormLSTM(
|
||||||
input_size=input_size,
|
input_size=input_size,
|
||||||
|
@ -1095,10 +1095,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
|||||||
assert lengths.ndim == 1, lengths.ndim
|
assert lengths.ndim == 1, lengths.ndim
|
||||||
max_len = max(max_len, lengths.max())
|
max_len = max(max_len, lengths.max())
|
||||||
n = lengths.size(0)
|
n = lengths.size(0)
|
||||||
|
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||||
|
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
||||||
|
|
||||||
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
return expaned_lengths >= lengths.unsqueeze(-1)
|
||||||
|
|
||||||
return expaned_lengths >= lengths.unsqueeze(1)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
|
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
|
||||||
|
Loading…
x
Reference in New Issue
Block a user