Minor fixes.

parent 8c98599ded
commit 49aaaf8021
@@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1

-        assert x.size(0) == lengths.max().item()
+        if not torch.jit.is_tracing():
+            assert x.size(0) == lengths.max().item()

         src_key_padding_mask = make_pad_mask(lengths)
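Under torch.jit.trace(), a value extracted with .item() becomes a Python constant, so an assert built on it would be frozen for the example input's lengths; skipping the check while tracing keeps the graph shape-agnostic. A minimal runnable sketch of the pattern, using a hypothetical Subsample module rather than the actual Conformer:

import torch

class Subsample(torch.nn.Module):
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        # same subsampling arithmetic as the patch: ((T - 1) // 2 - 1) // 2
        lengths = (((x_lens - 1) >> 1) - 1) >> 1
        if not torch.jit.is_tracing():
            # .item() would bake a constant into the trace; eager mode only
            assert x.size(1) == x_lens.max().item()
        return lengths

m = torch.jit.trace(Subsample(), (torch.rand(2, 20, 8), torch.tensor([20, 15])))
print(m(torch.rand(2, 36, 8), torch.tensor([36, 30])))  # tensor([8, 6])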
@@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        if torch.jit.is_tracing():
+            # 10k frames correspond to ~100k ms, i.e., ~100 seconds.
+            # It assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
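A fixed max_len matters because a traced module cannot grow its positional-encoding table on demand; whatever table-extension logic runs on the example input is all the trace keeps. Below is a minimal sketch (not the icefall module itself) of sizing a sinusoidal table once for the assumed worst case, after which forward() only slices it:

import math
import torch

d_model, max_len = 16, 10000  # 10k frames ~= 100 s at 10 ms per frame

pe = torch.zeros(max_len, d_model)
position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)  # even channels
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels

T = 123
emb = pe[:T]  # slicing traces cleanly for any T <= max_len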
@@ -1006,34 +1015,20 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape

         time2 = time1 + left_context
-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if not torch.jit.is_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"

-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            x = x.contiguous()
-            b = x.size(0)
-            h = x.size(1)
-            t = x.size(2)
-            c = x.size(3)
-
-            bh = b * h
-
-            if False:
-                rows = torch.arange(start=t - 1, end=-1, step=-1).unsqueeze(-1)
-                cols = torch.arange(t)
-                indexes = rows + cols
-                # onnx does not support torch.tile
-                indexes = torch.tile(indexes, (bh, 1))
-            else:
-                rows = torch.arange(start=t - 1, end=-1, step=-1)
-                cols = torch.arange(t)
-                rows = torch.cat([rows] * bh).unsqueeze(-1)
-                indexes = rows + cols
-
-            x = x.reshape(-1, c)
+        if torch.jit.is_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time1)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+
+            x = x.reshape(-1, n)
             x = torch.gather(x, dim=1, index=indexes)
-            x = x.reshape(b, h, t, t)
+            x = x.reshape(batch_size, num_heads, time1, time1)
             return x
         else:
             # Note: TorchScript requires explicit arg for stride()
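The gather-based branch is what makes rel_shift traceable: instead of as_strided with explicit strides, it builds index tensors into the flattened relative-position axis. A self-contained check of the trick for left_context == 0 (so n == 2 * time1 - 1), compared against a naive slicing reference:

import torch

b, h, t = 2, 4, 5      # batch, heads, time1
n = 2 * t - 1          # relative positions when left_context == 0
x = torch.randn(b, h, t, n)

rows = torch.arange(start=t - 1, end=-1, step=-1)
cols = torch.arange(t)
rows = rows.repeat(b * h).unsqueeze(-1)  # (b*h*t, 1)
indexes = rows + cols                    # (b*h*t, t)
y = torch.gather(x.reshape(-1, n), dim=1, index=indexes)
y = y.reshape(b, h, t, t)

# reference: out[..., i, j] == x[..., i, (t - 1) - i + j]
ref = torch.stack(
    [x[..., i, t - 1 - i : 2 * t - 1 - i] for i in range(t)], dim=2
)
assert torch.equal(y, ref)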
@@ -1116,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
         """

         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        if not torch.jit.is_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

         head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        if not torch.jit.is_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5

@@ -1235,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):

         src_len = k.size(0)

-        if key_padding_mask is not None:
+        if key_padding_mask is not None and not torch.jit.is_tracing():
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
@@ -1246,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)

         pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        if not torch.jit.is_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
+
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time1 - 1)
         p = p.permute(0, 2, 3, 1)
@@ -1281,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )

-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
+        if not torch.jit.is_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]

         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
@@ -1344,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
             )

         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+
+        if not torch.jit.is_tracing():
+            assert list(attn_output.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                head_dim,
+            ]
+
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
@@ -53,10 +53,9 @@ class Joiner(nn.Module):
           Return a tensor of shape (N, T, s_range, C).
         """
-
-        assert encoder_out.ndim == decoder_out.ndim
-        assert encoder_out.ndim in (2, 4)
-        assert encoder_out.shape == decoder_out.shape
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            assert encoder_out.ndim == decoder_out.ndim
+            assert encoder_out.ndim in (2, 4)
+            assert encoder_out.shape == decoder_out.shape

         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(
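The intent of the combined guard is that the shape checks run only in eager mode and vanish under both torch.jit.script() and torch.jit.trace(). A minimal sketch with a toy function standing in for Joiner.forward():

import torch

def toy_joiner(encoder_out: torch.Tensor, decoder_out: torch.Tensor):
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # eager-mode sanity checks; skipped when scripted or traced
        assert encoder_out.shape == decoder_out.shape
    return torch.tanh(encoder_out + decoder_out)

traced = torch.jit.trace(toy_joiner, (torch.rand(2, 3), torch.rand(2, 3)))
out = traced(torch.rand(4, 3), torch.rand(4, 3))  # shapes stay dynamic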
@@ -152,7 +152,8 @@ class BasicNorm(torch.nn.Module):
         self.register_buffer("eps", torch.tensor(eps).log().detach())

     def forward(self, x: Tensor) -> Tensor:
-        assert x.shape[self.channel_dim] == self.num_channels
+        if not torch.jit.is_tracing():
+            assert x.shape[self.channel_dim] == self.num_channels
         scales = (
             torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
@@ -228,7 +228,18 @@ def main():
     warmup = 1.0
     encoder_filename = params.exp_dir / "encoder.onnx"
+    # encoder_model = torch.jit.script(model.encoder)
+    # It throws the following error for the above statement:
+    #
+    # RuntimeError: Exporting the operator __is_ to ONNX opset version
+    # 11 is not supported. Please feel free to request support or
+    # submit a pull request on PyTorch GitHub.
+    #
+    # I cannot find which statement causes the above error.
+    # torch.onnx.export() will use torch.jit.trace() internally, which
+    # works well for the current reworked model.
+
     encoder_model = model.encoder

     torch.onnx.export(
         encoder_model,
         (x, x_lens, warmup),
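Because the encoder is passed to torch.onnx.export() un-scripted, the exporter traces it with the example inputs, which is exactly where the is_tracing() guards above pay off. A self-contained sketch of such an export call; DummyEncoder and every input/output/axis name below are illustrative assumptions, not the actual script's values:

import torch

class DummyEncoder(torch.nn.Module):
    # stand-in for model.encoder, only to make the call runnable
    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        return torch.relu(x), x_lens - 2  # toy output and length rule

x = torch.rand(1, 100, 80)
x_lens = torch.tensor([100])

torch.onnx.export(
    DummyEncoder(),
    (x, x_lens),  # example inputs for the internal torch.jit.trace()
    "encoder.onnx",
    opset_version=11,
    input_names=["x", "x_lens"],
    output_names=["encoder_out", "encoder_out_lens"],
    dynamic_axes={
        "x": {0: "N", 1: "T"},
        "x_lens": {0: "N"},
        "encoder_out": {0: "N", 1: "T"},
        "encoder_out_lens": {0: "N"},
    },
)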
@@ -76,8 +76,8 @@ def test_encoder(
     assert encoder_inputs[0].shape == ["N", "T", 80]
     assert encoder_inputs[1].shape == ["N"]

-    x = torch.rand(1, 100, 80, dtype=torch.float32)
-    x_lens = torch.tensor([100])
+    x = torch.rand(5, 50, 80, dtype=torch.float32)
+    x_lens = torch.tensor([50, 50, 20, 30, 10])

     encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()}
     encoder_out, encoder_out_lens = encoder_session.run(
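Switching the test input from a single utterance to a batch of five ragged lengths exercises both dynamic axes (N and T) of the exported graph. A hedged sketch of the equivalent onnxruntime check (file and tensor names assumed to match the export sketch above):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("encoder.onnx")

x = np.random.rand(5, 50, 80).astype(np.float32)
x_lens = np.array([50, 50, 20, 30, 10], dtype=np.int64)

encoder_out, encoder_out_lens = session.run(
    ["encoder_out", "encoder_out_lens"],  # assumed output names
    {"x": x, "x_lens": x_lens},
)
print(encoder_out.shape, encoder_out_lens)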