Minor fixes.

Fangjun Kuang 2022-07-28 15:52:51 +08:00
parent 8c98599ded
commit 49aaaf8021
5 changed files with 64 additions and 46 deletions

View File

@@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
+        if not torch.jit.is_tracing():
+            assert x.size(0) == lengths.max().item()
         src_key_padding_mask = make_pad_mask(lengths)
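
Not part of the commit, just a sanity check on the length formula above: the two right-shifts presumably mirror the two stride-2 convolutions in the subsampling front-end, and they sidestep torch.div(..., rounding_mode="floor"), which is only available in torch >= 1.8.0. A minimal sketch (the example lengths are made up):

import torch

# Hypothetical input lengths in feature frames; names are illustrative.
x_lens = torch.tensor([100, 99, 17])

# Shift-based form used above: works on any torch version.
shifted = (((x_lens - 1) >> 1) - 1) >> 1

# Equivalent floor-division form; rounding_mode needs torch >= 1.8.0.
divided = torch.div(
    torch.div(x_lens - 1, 2, rounding_mode="floor") - 1,
    2,
    rounding_mode="floor",
)

assert torch.equal(shifted, divided)
print(shifted)  # tensor([24, 24, 3])
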
@@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        if torch.jit.is_tracing():
+            # 10k frames correspond to ~100k ms, i.e., 100 seconds.
+            # It assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
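
A rough, illustrative check of the 10k-frame budget picked above (the frame shift, model dimension, and table shape below are assumptions, not values read from this file); pre-sizing the table lets torch.jit.trace() avoid an input-dependent allocation:

import torch

frame_shift_ms = 10  # assumed feature frame shift
max_len = 10000      # upper bound used when tracing
print(max_len * frame_shift_ms / 1000, "seconds")  # 100.0 seconds

d_model = 512        # illustrative model dimension
# Relative offsets span -(max_len - 1) .. (max_len - 1): 2 * max_len - 1 rows.
pe = torch.zeros(2 * max_len - 1, d_model)
print(pe.shape)      # torch.Size([19999, 512])
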
@@ -1006,34 +1015,20 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if not torch.jit.is_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            x = x.contiguous()
-            b = x.size(0)
-            h = x.size(1)
-            t = x.size(2)
-            c = x.size(3)
+        if torch.jit.is_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time1)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
-            bh = b * h
-            if False:
-                rows = torch.arange(start=t - 1, end=-1, step=-1).unsqueeze(-1)
-                cols = torch.arange(t)
-                indexes = rows + cols
-                # onnx does not support torch.tile
-                indexes = torch.tile(indexes, (bh, 1))
-            else:
-                rows = torch.arange(start=t - 1, end=-1, step=-1)
-                cols = torch.arange(t)
-                rows = torch.cat([rows] * bh).unsqueeze(-1)
-                indexes = rows + cols
-            x = x.reshape(-1, c)
+            x = x.reshape(-1, n)
             x = torch.gather(x, dim=1, index=indexes)
-            x = x.reshape(b, h, t, t)
+            x = x.reshape(batch_size, num_heads, time1, time1)
             return x
         else:
             # Note: TorchScript requires explicit arg for stride()
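
For reference, the gather-based shift added above can be exercised on its own. This standalone sketch assumes left_context == 0 and uses made-up sizes; it reproduces the row/column index trick and checks the result element by element:

import torch

def rel_shift_gather(x: torch.Tensor) -> torch.Tensor:
    """Gather-based relative shift (left_context assumed to be 0).

    x: (batch, heads, time1, 2 * time1 - 1)
    returns: (batch, heads, time1, time1) where
             out[b, h, i, j] == x[b, h, i, time1 - 1 - i + j]
    """
    batch, heads, time1, n = x.shape
    rows = torch.arange(start=time1 - 1, end=-1, step=-1)
    cols = torch.arange(time1)
    rows = rows.repeat(batch * heads).unsqueeze(-1)
    indexes = rows + cols
    x = x.reshape(-1, n)
    x = torch.gather(x, dim=1, index=indexes)
    return x.reshape(batch, heads, time1, time1)

batch, heads, time1 = 2, 4, 5
x = torch.randn(batch, heads, time1, 2 * time1 - 1)
y = rel_shift_gather(x)
for i in range(time1):
    for j in range(time1):
        assert torch.equal(y[0, 0, i, j], x[0, 0, i, time1 - 1 - i + j])
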
@@ -1116,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        if not torch.jit.is_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
         head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        if not torch.jit.is_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"
         scaling = float(head_dim) ** -0.5
@@ -1235,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
         src_len = k.size(0)
-        if key_padding_mask is not None:
+        if key_padding_mask is not None and not torch.jit.is_tracing():
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
@@ -1246,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)
         pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        if not torch.jit.is_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time1 - 1)
         p = p.permute(0, 2, 3, 1)
@@ -1281,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
+        if not torch.jit.is_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
@@ -1344,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
         )
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+        if not torch.jit.is_tracing():
+            assert list(attn_output.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                head_dim,
+            ]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
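
The recurring pattern in this file, wrapping shape asserts in `if not torch.jit.is_tracing():`, exists because calls like `.item()` pull a concrete Python value out of a traced tensor (with a TracerWarning), and size comparisons are evaluated once on the example input; either way the assert never becomes part of the traced graph, so it is simply skipped while tracing. A toy illustration with a made-up module:

import torch

class Toy(torch.nn.Module):  # hypothetical module, just for illustration
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.jit.is_tracing():
            # Runs in eager mode, but is skipped while tracing, so no
            # TracerWarning and no shape baked in from the example input.
            assert x.size(0) == 2, x.size(0)
        return x * 2

m = Toy()
traced = torch.jit.trace(m, torch.zeros(2, 3))
# The traced module accepts other batch sizes; the assert was only an
# eager-mode check and is not recorded in the traced graph.
print(traced(torch.zeros(5, 3)).shape)  # torch.Size([5, 3])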

View File

@@ -53,10 +53,9 @@ class Joiner(nn.Module):
           Return a tensor of shape (N, T, s_range, C).
         """
-        if not torch.jit.is_scripting() or not torch.jit.is_tracing():
-            assert encoder_out.ndim == decoder_out.ndim
-            assert encoder_out.ndim in (2, 4)
-            assert encoder_out.shape == decoder_out.shape
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
+        assert encoder_out.shape == decoder_out.shape
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(
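
On the Joiner change: the removed condition `not torch.jit.is_scripting() or not torch.jit.is_tracing()` is False only when both flags are True at once, which cannot happen, so it never actually disabled the asserts; the commit simply makes them unconditional. The boolean point in isolation:

# Truth table for the removed guard: it is False only when both flags are
# True at the same time, which scripting and tracing never are.
for is_scripting in (False, True):
    for is_tracing in (False, True):
        guard = (not is_scripting) or (not is_tracing)
        print(is_scripting, is_tracing, guard)
# False False True
# False True  True
# True  False True
# True  True  False  <- the only skipping case, and it cannot occur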

View File

@@ -152,7 +152,8 @@ class BasicNorm(torch.nn.Module):
         self.register_buffer("eps", torch.tensor(eps).log().detach())
     def forward(self, x: Tensor) -> Tensor:
-        assert x.shape[self.channel_dim] == self.num_channels
+        if not torch.jit.is_tracing():
+            assert x.shape[self.channel_dim] == self.num_channels
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
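
For context on the forward() being edited here: BasicNorm scales the input by an inverse smoothed RMS over the channel dimension. A minimal eager-mode sketch, a simplification with made-up values rather than a copy of the module:

import torch

def basic_norm(x: torch.Tensor, eps_log: float = -1.0, channel_dim: int = -1):
    # eps is stored in log space (see register_buffer above), hence .exp().
    eps = torch.tensor(eps_log).exp()
    scales = (torch.mean(x ** 2, dim=channel_dim, keepdim=True) + eps) ** -0.5
    return x * scales

x = torch.randn(2, 5, 8)  # (batch, time, channels), illustrative
print(basic_norm(x).shape)  # torch.Size([2, 5, 8])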

View File

@@ -228,7 +228,18 @@ def main():
    warmup = 1.0
    encoder_filename = params.exp_dir / "encoder.onnx"
+    # encoder_model = torch.jit.script(model.encoder)
+    # It throws the following error for the above statement
+    #
+    # RuntimeError: Exporting the operator __is_ to ONNX opset version
+    # 11 is not supported. Please feel free to request support or
+    # submit a pull request on PyTorch GitHub.
+    #
+    # I cannot find which statement causes the above error.
+    # torch.onnx.export() will use torch.jit.trace() internally, which
+    # works well for the current reworked model
    encoder_model = model.encoder
    torch.onnx.export(
        encoder_model,
        (x, x_lens, warmup),
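As the new comment explains, torch.onnx.export() traces the module rather than scripting it, so the encoder is passed in eagerly. A hedged sketch of how the call begun above might be completed, reusing the script's own encoder_model / x / x_lens / warmup / encoder_filename; the opset choice and dynamic-axes layout are illustrative, though the names x, x_lens, encoder_out and encoder_out_lens match the test file below:

# Sketch only: continues the torch.onnx.export(...) call shown above.
torch.onnx.export(
    encoder_model,
    (x, x_lens, warmup),         # example inputs used by torch.jit.trace()
    str(encoder_filename),
    opset_version=11,            # illustrative; must cover every exported op
    input_names=["x", "x_lens"],
    output_names=["encoder_out", "encoder_out_lens"],
    dynamic_axes={
        "x": {0: "N", 1: "T"},   # batch and time vary at run time
        "x_lens": {0: "N"},
        "encoder_out": {0: "N", 1: "T"},
        "encoder_out_lens": {0: "N"},
    },
)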

View File

@@ -76,8 +76,8 @@ def test_encoder(
    assert encoder_inputs[0].shape == ["N", "T", 80]
    assert encoder_inputs[1].shape == ["N"]
-    x = torch.rand(1, 100, 80, dtype=torch.float32)
-    x_lens = torch.tensor([100])
+    x = torch.rand(5, 50, 80, dtype=torch.float32)
+    x_lens = torch.tensor([50, 50, 20, 30, 10])
    encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
    encoder_out, encoder_out_lens = encoder_session.run(
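A sketch of how the truncated call above might continue, assuming the session was created with onnxruntime and that the output names follow the export sketch earlier; the expected values are illustrative:

encoder_out, encoder_out_lens = encoder_session.run(
    ["encoder_out", "encoder_out_lens"],  # assumed output names
    encoder_inputs,
)

# With dynamic batch/time axes the exported model should accept this
# 5-utterance batch; the returned lengths should follow the subsampling
# formula from conformer.py, i.e. ((len - 1) // 2 - 1) // 2.
print(encoder_out.shape)
print(encoder_out_lens)  # expect [11, 11, 4, 6, 1] for the lengths above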