Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Fix exporting non-streaming zipformer to ONNX via torch.jit.trace()
parent f43c44236f
commit dd91c89f28
@@ -13,7 +13,7 @@ as an example to show how to use this file.
 
 cd egs/librispeech/ASR
 
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -27,12 +27,29 @@ popd
 
 2. Export the model to ONNX
 
-./zipformer/export-onnx.py \ 148 ↵
+./zipformer/export-onnx.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
-  --exp-dir $repo/exp/
+  --exp-dir $repo/exp \
+  --num-encoder-layers "2,2,3,4,3,2" \
+  --downsampling-factor "1,2,4,8,4,2" \
+  --feedforward-dim "512,768,1024,1536,1024,768" \
+  --num-heads "4,4,4,8,4,4" \
+  --encoder-dim "192,256,384,512,384,256" \
+  --query-head-dim 32 \
+  --value-head-dim 12 \
+  --pos-head-dim 4 \
+  --pos-dim 48 \
+  --encoder-unmasked-dim "192,192,256,256,256,192" \
+  --cnn-module-kernel "31,31,15,15,15,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512 \
+  --causal False \
+  --chunk-size "16,32,64,-1" \
+  --left-context-frames "64,128,256,-1"
 
 It will generate the following 3 files inside $repo/exp:
 
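For a quick smoke test once the export finishes, the non-streaming encoder can be loaded with onnxruntime and run on a dummy batch. This sketch is not part of the commit; the file name follows the usual icefall naming for --epoch 99 --avg 1, and the 80-dim fbank / 100-frame input shape is an assumption based on the recipe defaults.

import numpy as np
import onnxruntime as ort

# Assumed output name; the actual names are listed in the docstring above.
session = ort.InferenceSession(
    "icefall-asr-librispeech-zipformer-2023-05-15/exp/encoder-epoch-99-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)

# A batch of one utterance: 100 frames of 80-dim fbank features.
x = np.random.randn(1, 100, 80).astype(np.float32)
x_lens = np.array([100], dtype=np.int64)

inputs = {
    session.get_inputs()[0].name: x,
    session.get_inputs()[1].name: x_lens,
}
encoder_out, encoder_out_lens = session.run(None, inputs)
print(encoder_out.shape, encoder_out_lens)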
@@ -299,7 +316,7 @@ def export_encoder_model_onnx(
         "model_type": "zipformer",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "stateless7",
+        "comment": "zipformer",
     }
     logging.info(f"meta_data: {meta_data}")
 
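The meta_data dict above is embedded in the exported file as ONNX metadata_props, which downstream runtimes read to identify the model. A minimal standalone sketch of that mechanism, mirroring what the recipe's add_meta_data() helper does (simplified here; the path is a placeholder):

import onnx

def add_meta_data(filename: str, meta_data: dict) -> None:
    """Attach key/value pairs to an ONNX model's metadata_props."""
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value
    onnx.save(model, filename)

add_meta_data(
    "encoder.onnx",  # placeholder path for illustration
    {
        "model_type": "zipformer",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "zipformer",
    },
)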
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           encoder:
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
-            It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
             `logit_lens` of shape (N,).
           decoder:
             It is the prediction network in the paper. Its input shape
egs/librispeech/ASR/zipformer/onnx_pretrained.py (new symbolic link)
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/onnx_pretrained.py
@@ -26,6 +26,18 @@ import torch.nn as nn
 from torch import Tensor
 
 
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    if not torch.jit.is_tracing():
+        return torch.logaddexp(x, y)
+    else:
+        return (x.exp() + y.exp()).log()
+
+
 class PiecewiseLinear(object):
     """
     Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
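The fallback body is the same quantity written out directly: logaddexp(x, y) = log(e^x + e^y). What it gives up is the overflow-safe evaluation torch.logaddexp provides, and since the traced branch is baked into the exported graph, the naive form also runs at inference time. That is acceptable here because, in SwooshL/SwooshR below, the arguments are activations shifted by small constants and in practice stay far below the float32 overflow point of exp (~88). A small check, not from the commit:

import torch

x = torch.randn(1000)
y = torch.randn(1000)

# The two forms agree for moderate values...
print(torch.allclose(torch.logaddexp(x, y), (x.exp() + y.exp()).log(), atol=1e-6))

# ...but the naive form overflows where torch.logaddexp stays finite.
big = torch.tensor([100.0])
print(torch.logaddexp(big, big))      # tensor([100.6931])
print((big.exp() + big.exp()).log())  # tensor([inf])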
@@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
 
     def __float__(self):
         batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
+        if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return float(self.default)
         else:
             ans = self.schedule(self.batch_count)
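This hunk is the first of many below with the same shape: every torch.jit.is_scripting() test gains a torch.jit.is_tracing() companion. The reason is that torch.jit.trace() runs the module as ordinary eager Python, where is_scripting() is False, so without the extra test tracing would walk into the training-only branches (random skip decisions, custom autograd functions) and record them in the exported graph. A toy illustration of the guard, not from the commit:

import torch

class Guarded(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            print("export branch")  # taken while torch.jit.trace() records forward()
            return x * torch.sigmoid(x - 1.0)
        print("eager branch")  # stand-in for a training-only / custom-autograd path
        return x * torch.sigmoid(x - 1.0)

m = Guarded()
m(torch.randn(2, 3))                            # prints "eager branch"
traced = torch.jit.trace(m, torch.randn(2, 3))  # prints "export branch"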
@@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
 
 
 def softmax(x: Tensor, dim: int):
-    if not x.requires_grad or torch.jit.is_scripting():
+    if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim=dim)
 
     return SoftmaxFunction.apply(x, dim)
@@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
         self.alpha = alpha
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return x
         return scale_grad(x, self.alpha)
 
@@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
 
 
 def _no_op(x: Tensor) -> Tensor:
-    if (torch.jit.is_scripting()):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x
     else:
         # a no-op function that will have a node in the autograd graph,
@@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
 
@@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-L activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+            return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
             return k2.swoosh_l_forward(x)
         else:
@@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-R activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+            return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         if not x.requires_grad:
             return k2.swoosh_r_forward(x)
         else:
@@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
         )
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.forward_internal(x)
         layerdrop_rate = float(self.layerdrop_rate)
 
@@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.dropout(x)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             x_lens = (x_lens - 7) // 2
         else:
             with warnings.catch_warnings():
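Here too the guard routes scripting and tracing through plain integer arithmetic: the convolutional front end consumes 7 frames of context and halves the frame rate. A worked example of that formula, not from the commit:

import torch

x_lens = torch.tensor([100, 57, 8])
print((x_lens - 7) // 2)  # tensor([46, 25, 0])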
@@ -258,7 +258,7 @@ class Zipformer2(EncoderInterface):
         if not self.causal:
             return -1, -1
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             assert len(self.chunk_size) == 1, self.chunk_size
             chunk_size = self.chunk_size[0]
         else:
@@ -267,7 +267,7 @@ class Zipformer2(EncoderInterface):
         if chunk_size == -1:
             left_context_chunks = -1
         else:
-            if torch.jit.is_scripting():
+            if torch.jit.is_scripting() or torch.jit.is_tracing():
                 assert len(self.left_context_frames) == 1, self.left_context_frames
                 left_context_frames = self.left_context_frames[0]
             else:
@@ -301,14 +301,14 @@ class Zipformer2(EncoderInterface):
             of frames in `embeddings` before padding.
         """
         outputs = []
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             feature_masks = [1.0] * len(self.encoder_dim)
         else:
             feature_masks = self.get_feature_masks(x)
 
         chunk_size, left_context_chunks = self.get_chunk_info()
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # Not support exporting a model for simulating streaming decoding
             attn_mask = None
         else:
@@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
         else:
             with warnings.catch_warnings():
@@ -372,7 +372,7 @@ class Zipformer2(EncoderInterface):
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             c = t // chunk_size
         else:
             with warnings.catch_warnings():
@@ -650,7 +650,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )
 
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
-        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
            return None
         batch_size = x.shape[1]
         mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@@ -695,7 +695,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             attention_skip_rate = 0.0
         else:
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@@ -713,7 +713,7 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
@@ -732,7 +732,7 @@ class Zipformer2EncoderLayer(nn.Module):
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -740,7 +740,7 @@ class Zipformer2EncoderLayer(nn.Module):
                 src_key_padding_mask=src_key_padding_mask),
             conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@@ -754,7 +754,7 @@ class Zipformer2EncoderLayer(nn.Module):
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -762,7 +762,7 @@ class Zipformer2EncoderLayer(nn.Module):
                 src_key_padding_mask=src_key_padding_mask),
             conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@@ -968,7 +968,7 @@ class Zipformer2Encoder(nn.Module):
         pos_emb = self.encoder_pos(src)
         output = src
 
-        if not torch.jit.is_scripting():
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
@@ -980,7 +980,7 @@ class Zipformer2Encoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )
 
-            if not torch.jit.is_scripting():
+            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                 output = output * feature_mask
 
         return output
@@ -1073,7 +1073,7 @@ class BypassModule(nn.Module):
         # or (batch_size, num_channels,). This is actually the
         # scale on the non-residual term, so 0 correponds to bypassing
         # this module.
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.bypass_scale
         else:
             ans = limit_param_value(self.bypass_scale,
@@ -1524,7 +1524,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)
 
         use_pos_scores = False
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = attn_scores + pos_scores
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
@@ -1604,7 +1604,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)
@@ -2146,7 +2146,7 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
 
-        if not torch.jit.is_scripting() and chunk_size >= 0:
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
             # Not support exporting a model for simulated streaming decoding
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
             x = self.depthwise_conv(x, chunk_size=chunk_size)