From 49aaaf80212913c445eaa696b8c537d401247508 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 28 Jul 2022 15:52:51 +0800
Subject: [PATCH] Minor fixes.

---
 .../pruned_transducer_stateless2/conformer.py | 85 ++++++++++---------
 .../pruned_transducer_stateless2/joiner.py    |  7 +-
 .../pruned_transducer_stateless2/scaling.py   |  3 +-
 .../pruned_transducer_stateless3/export.py    | 11 +++
 .../onnx_check.py                             |  4 +-
 5 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
index 52c227cf8..00549c086 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
@@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
         lengths = (((x_lens - 1) >> 1) - 1) >> 1
-        assert x.size(0) == lengths.max().item()
+        if not torch.jit.is_tracing():
+            assert x.size(0) == lengths.max().item()
 
         src_key_padding_mask = make_pad_mask(lengths)
 
@@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
+        if torch.jit.is_tracing():
+            # 10k frames correspond to ~100k ms, i.e., 100 seconds.
+            # It assumes that the maximum input won't have more than
+            # 10k frames.
+            #
+            # TODO(fangjun): Use torch.jit.script() for this module
+            max_len = 10000
+
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
@@ -1006,34 +1015,20 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
-        assert (
-            n == left_context + 2 * time1 - 1
-        ), f"{n} == {left_context} + 2 * {time1} - 1"
+        if not torch.jit.is_tracing():
+            assert (
+                n == left_context + 2 * time1 - 1
+            ), f"{n} == {left_context} + 2 * {time1} - 1"
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            x = x.contiguous()
-            b = x.size(0)
-            h = x.size(1)
-            t = x.size(2)
-            c = x.size(3)
+        if torch.jit.is_tracing():
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(time1)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
 
-            bh = b * h
-
-            if False:
-                rows = torch.arange(start=t - 1, end=-1, step=-1).unsqueeze(-1)
-                cols = torch.arange(t)
-                indexes = rows + cols
-                # onnx does not support torch.tile
-                indexes = torch.tile(indexes, (bh, 1))
-            else:
-                rows = torch.arange(start=t - 1, end=-1, step=-1)
-                cols = torch.arange(t)
-                rows = torch.cat([rows] * bh).unsqueeze(-1)
-                indexes = rows + cols
-
-            x = x.reshape(-1, c)
+            x = x.reshape(-1, n)
             x = torch.gather(x, dim=1, index=indexes)
-            x = x.reshape(b, h, t, t)
+            x = x.reshape(batch_size, num_heads, time1, time1)
             return x
         else:
             # Note: TorchScript requires explicit arg for stride()
@@ -1116,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         tgt_len, bsz, embed_dim = query.size()
-        assert embed_dim == embed_dim_to_check
-        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+        if not torch.jit.is_tracing():
+            assert embed_dim == embed_dim_to_check
+            assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
 
         head_dim = embed_dim // num_heads
-        assert (
-            head_dim * num_heads == embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        if not torch.jit.is_tracing():
+            assert (
+                head_dim * num_heads == embed_dim
+            ), "embed_dim must be divisible by num_heads"
 
         scaling = float(head_dim) ** -0.5
@@ -1235,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
 
         src_len = k.size(0)
 
-        if key_padding_mask is not None:
+        if key_padding_mask is not None and not torch.jit.is_tracing():
             assert key_padding_mask.size(0) == bsz, "{} == {}".format(
                 key_padding_mask.size(0), bsz
             )
@@ -1246,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.transpose(0, 1)  # (batch, time1, head, d_k)
 
         pos_emb_bsz = pos_emb.size(0)
-        assert pos_emb_bsz in (1, bsz)  # actually it is 1
+        if not torch.jit.is_tracing():
+            assert pos_emb_bsz in (1, bsz)  # actually it is 1
+
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
         p = p.permute(0, 2, 3, 1)
@@ -1281,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
             bsz * num_heads, tgt_len, -1
         )
 
-        assert list(attn_output_weights.size()) == [
-            bsz * num_heads,
-            tgt_len,
-            src_len,
-        ]
+        if not torch.jit.is_tracing():
+            assert list(attn_output_weights.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                src_len,
+            ]
 
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
@@ -1344,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
             )
 
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
+
+        if not torch.jit.is_tracing():
+            assert list(attn_output.size()) == [
+                bsz * num_heads,
+                tgt_len,
+                head_dim,
+            ]
+
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
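The traced branch of the relative-shift helper in RelPositionMultiheadAttention above keeps only the torch.gather-based indexing and drops the dead torch.tile branch (the removed comment notes that the ONNX exporter does not support torch.tile). The snippet below is a minimal standalone sketch of that indexing, not taken from the patch, assuming left_context == 0 so that n == 2 * time1 - 1:

import torch


def rel_shift_by_gather(x: torch.Tensor) -> torch.Tensor:
    # x: (batch_size, num_heads, time1, n) with n == 2 * time1 - 1
    (batch_size, num_heads, time1, n) = x.shape
    # Row i of each (batch, head) slice keeps columns
    # [time1 - 1 - i, ..., 2 * time1 - 2 - i]
    rows = torch.arange(start=time1 - 1, end=-1, step=-1)
    cols = torch.arange(time1)
    rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
    indexes = rows + cols
    x = x.reshape(-1, n)
    x = torch.gather(x, dim=1, index=indexes)
    return x.reshape(batch_size, num_heads, time1, time1)


x = torch.arange(30, dtype=torch.float32).reshape(2, 1, 3, 5)
print(rel_shift_by_gather(x))  # -> shape (2, 1, 3, 3)

The price for avoiding as_strided-style views while tracing is the materialized indexes tensor of shape (batch_size * num_heads * time1, time1).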
num_heads" scaling = float(head_dim) ** -0.5 @@ -1235,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module): src_len = k.size(0) - if key_padding_mask is not None: + if key_padding_mask is not None and not torch.jit.is_tracing(): assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz ) @@ -1246,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module): q = q.transpose(0, 1) # (batch, time1, head, d_k) pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 + if not torch.jit.is_tracing(): + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) p = p.permute(0, 2, 3, 1) @@ -1281,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module): bsz * num_heads, tgt_len, -1 ) - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] + if not torch.jit.is_tracing(): + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] if attn_mask is not None: if attn_mask.dtype == torch.bool: @@ -1344,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + + if not torch.jit.is_tracing(): + assert list(attn_output.size()) == [ + bsz * num_heads, + tgt_len, + head_dim, + ] + attn_output = ( attn_output.transpose(0, 1) .contiguous() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 998d575b6..b916addf0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -53,10 +53,9 @@ class Joiner(nn.Module): Return a tensor of shape (N, T, s_range, C). 
""" - if not torch.jit.is_scripting() or not torch.jit.is_tracing(): - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape == decoder_out.shape + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) + assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index e7c2e55f4..2b44dc649 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -152,7 +152,8 @@ class BasicNorm(torch.nn.Module): self.register_buffer("eps", torch.tensor(eps).log().detach()) def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels + if not torch.jit.is_tracing(): + assert x.shape[self.channel_dim] == self.num_channels scales = ( torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + self.eps.exp() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index d233ddf2f..1fa39ceb7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -228,7 +228,18 @@ def main(): warmup = 1.0 encoder_filename = params.exp_dir / "encoder.onnx" # encoder_model = torch.jit.script(model.encoder) + # It throws the following error for the above statement + # + # RuntimeError: Exporting the operator __is_ to ONNX opset version + # 11 is not supported. Please feel free to request support or + # submit a pull request on PyTorch GitHub. + # + # I cannot find which statement causes the above error. + # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + encoder_model = model.encoder + torch.onnx.export( encoder_model, (x, x_lens, warmup), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 3c3fbb1f5..d7379c22e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -76,8 +76,8 @@ def test_encoder( assert encoder_inputs[0].shape == ["N", "T", 80] assert encoder_inputs[1].shape == ["N"] - x = torch.rand(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100]) + x = torch.rand(5, 50, 80, dtype=torch.float32) + x_lens = torch.tensor([50, 50, 20, 30, 10]) encoder_inputs = {"x": x.numpy(), "x_lens": x_lens.numpy()} encoder_out, encoder_out_lens = encoder_session.run(