Add torch.jit.export

pkufool 2022-05-29 16:25:01 +08:00
parent 605838da55
commit 0325e3a04e
7 changed files with 415 additions and 153 deletions
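For context: torch.jit.script only compiles forward() and the methods it calls, so extra entry points such as streaming_forward and chunk_forward must carry @torch.jit.export to survive scripting. A minimal sketch of the mechanism (illustrative, not from this commit):

import torch
import torch.nn as nn


class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export
    def streaming_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Without @torch.jit.export, torch.jit.script would drop this
        # method, since only forward() and its callees are compiled.
        return x * 2


scripted = torch.jit.script(Toy())
scripted.streaming_forward(torch.ones(2))  # works because of the export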

View File

@@ -375,9 +375,10 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True,

View File

@@ -109,6 +109,47 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
return parser
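These four flags exist only so that export.py can rebuild the same transducer model as train.py. A self-contained sketch of how they parse (str2bool here is a stand-in for icefall's helper of the same name):

import argparse


def str2bool(v: str) -> bool:
    # Stand-in for icefall's str2bool helper.
    return v.lower() in ("yes", "true", "t", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--dynamic-chunk-training", type=str2bool, default=False)
parser.add_argument("--short-chunk-size", type=int, default=25)
parser.add_argument("--num-left-chunks", type=int, default=4)
parser.add_argument("--causal-convolution", type=str2bool, default=False)

# Exporting a streaming model would enable both boolean flags:
args = parser.parse_args(
    ["--dynamic-chunk-training", "True", "--causal-convolution", "True"]
)
assert args.dynamic_chunk_training and args.causal_convolution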
@@ -130,6 +171,7 @@ def main():
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
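The added vocab_size line mirrors train.py. A minimal sketch of these sentencepiece lookups (the model path is hypothetical):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

blank_id = sp.piece_to_id("<blk>")  # id of the blank symbol
unk_id = sp.piece_to_id("<unk>")    # id of the unknown symbol
vocab_size = sp.get_piece_size()    # total number of BPE pieces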

View File

@@ -288,7 +288,6 @@ def get_parser():
""",
)
return parser

View File

@@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
@@ -157,7 +157,6 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
mask = None
if self.dynamic_chunk_training:
assert (
@@ -176,24 +175,32 @@ class Conformer(EncoderInterface):
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x, _ = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
else:
x = self.encoder(
x,
pos_emb,
mask=None,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
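For reference, a hand-rolled sketch of the kind of chunk-wise attention mask used in the dynamic_chunk_training branch above (icefall's actual mask helper differs in details; True marks positions a query may NOT attend to, the PyTorch convention for boolean attention masks):

import torch


def chunk_mask(size: int, chunk_size: int, num_left_chunks: int) -> torch.Tensor:
    # Each frame may attend to its own chunk plus num_left_chunks chunks
    # of left context; everything else is masked out.
    mask = torch.ones(size, size, dtype=torch.bool)
    for i in range(size):
        chunk = i // chunk_size
        start = max(0, (chunk - num_left_chunks) * chunk_size)
        end = min(size, (chunk + 1) * chunk_size)
        mask[i, start:end] = False
    return mask


m = chunk_mask(size=8, chunk_size=2, num_left_chunks=1)
# Frame 4 (chunk 2) sees chunks 1-2, i.e. frames 2..5, and nothing else.
assert m[4].tolist() == [True, True, False, False, False, False, True, True]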
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: List[Tensor],
warmup: float = 1.0,
states: Optional[List[Tensor]] = None,
chunk_size: int = 16,
left_context: int = 64,
simulate_streaming: bool = False,
@@ -205,17 +212,17 @@ class Conformer(EncoderInterface):
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(encoder_layers, left_context, batch, attention_dim); the second is
the conv_cache, with shape
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
Note: states will be modified in this function.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
chunk_size:
The chunk size for decoding; this will be used to simulate streaming
decoding using masking.
@@ -245,10 +252,6 @@ class Conformer(EncoderInterface):
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert (
states is not None
), "Require cache when sending data in streaming mode"
assert (
len(states) == 2
and states[0].shape
@@ -272,7 +275,7 @@ class Conformer(EncoderInterface):
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder(
x = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
@@ -282,7 +285,6 @@ class Conformer(EncoderInterface):
) # (T, B, F)
else:
assert states is None
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
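In real streaming decoding the caller allocates the cache up front with the shapes documented above; a sketch, with all sizes assumed for illustration:

import torch

# Assumed sizes; the shapes follow the streaming_forward docstring.
encoder_layers, left_context, batch = 12, 64, 1
attention_dim, conv_dim, cnn_module_kernel = 512, 512, 31

states = [
    # attn_cache: (encoder_layers, left_context, batch, attention_dim)
    torch.zeros(encoder_layers, left_context, batch, attention_dim),
    # conv_cache: (encoder_layers, cnn_module_kernel-1, batch, conv_dim)
    torch.zeros(encoder_layers, cnn_module_kernel - 1, batch, conv_dim),
]
# streaming_forward() then updates both tensors in place per chunk.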
@@ -392,9 +394,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tuple[Tensor]:
) -> Tensor:
"""
Pass the input through the encoder layer.
@@ -405,20 +405,9 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(left_context, batch, attention_dim); the second is
the conv_cache, with shape
(cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
@@ -440,15 +429,82 @@ class ConformerEncoderLayer(nn.Module):
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
key = src
val = src
if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
else:
assert left_context == 0
# multi-headed self-attention module
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# convolution module
conv, _ = self.conv_module(src)
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(left_context, batch, attention_dim); the second is
the conv_cache, with shape
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (left_context, src.size(1), src.size(2))
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
# multi-headed self-attention module
src_att = self.self_attn(
@@ -464,11 +520,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
if not self.training and states is not None:
conv, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache
else:
conv = self.conv_module(src)
conv, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache
src = src + self.dropout(conv)
# feed forward module
@@ -476,9 +530,6 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src
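The attn_cache update above prepends left_context cached frames to the keys and values while queries remain chunk-sized; a stand-alone sketch of that bookkeeping (shapes follow the (time, batch, feature) layout used here):

import torch

left_context, batch, dim = 4, 1, 8
attn_cache = torch.zeros(left_context, batch, dim)

chunk = torch.randn(2, batch, dim)           # a new 2-frame chunk
key = torch.cat([attn_cache, chunk], dim=0)  # keys/values: cache + chunk
val = key
query = chunk                                # queries: the chunk only
attn_cache = key[-left_context:, ...]        # keep newest left_context frames

assert key.size(0) == left_context + chunk.size(0)
assert attn_cache.size(0) == left_context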
@@ -511,8 +562,6 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@@ -523,20 +572,10 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(encoder_layers, left_context, batch, attention_dim); the second is
the conv_cache, with shape
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@@ -544,30 +583,81 @@ class ConformerEncoder(nn.Module):
"""
output = src
if self.training:
assert left_context == 0
assert states is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers):
cache = (
None
if states is None
else [states[0][layer_index], states[1][layer_index]]
)
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(encoder_layers, left_context, batch, attention_dim); the second is
the conv_cache, with shape
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (
len(self.layers),
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == len(self.layers)
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output,
pos_emb,
states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context,
)
if states is not None:
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output
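Putting the pieces together, a hypothetical decode loop that carries the cache across chunks (the streaming_forward signature is taken from this diff; everything else is assumed):

import torch


def stream_utterance(encoder, chunks, states, chunk_size: int, left_context: int):
    # chunks: an iterable of (batch, chunk_frames, feature_dim) tensors.
    outputs = []
    for feature in chunks:
        x_lens = torch.full(
            (feature.size(0),), feature.size(1), dtype=torch.int64
        )
        out, out_lens = encoder.streaming_forward(
            x=feature,
            x_lens=x_lens,
            states=states,  # modified in place; reused on the next call
            chunk_size=chunk_size,
            left_context=left_context,
        )
        outputs.append(out)
    return torch.cat(outputs, dim=1)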
@@ -1216,7 +1306,7 @@ class ConvolutionModule(nn.Module):
self,
x: Tensor,
cache: Optional[Tensor] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
@@ -1260,9 +1350,11 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return (
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
)
# torch.jit.script requires return types be the same as annotated above
if cache is None:
cache = torch.empty(0)
return x.permute(2, 0, 1), cache
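The empty-tensor placeholder is needed because TorchScript requires one concrete return type; a minimal scriptable sketch of the pattern:

from typing import Optional, Tuple

import torch


@torch.jit.script
def conv_like(
    x: torch.Tensor, cache: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if cache is None:
        # TorchScript cannot script a Union return, so an empty tensor
        # stands in for "no cache".
        cache = torch.empty(0)
    return x + 1, cache


y, c = conv_like(torch.zeros(2))
assert c.numel() == 0  # callers treat an empty cache as absent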
class Conv2dSubsampling(nn.Module):

View File

@@ -335,9 +335,10 @@ def decode_one_batch(
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True,

View File

@@ -124,6 +124,47 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
return parser

View File

@@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
@@ -155,7 +155,6 @@ class Conformer(Transformer):
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
mask = None
if self.dynamic_chunk_training:
assert (
@@ -174,10 +173,13 @@ class Conformer(Transformer):
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x, _ = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
x = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
else:
x = self.encoder(
x, pos_emb, mask=None, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
@@ -187,11 +189,12 @@ class Conformer(Transformer):
return logits, lengths
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[torch.Tensor]] = None,
states: List[torch.Tensor],
chunk_size: int = 16,
left_context: int = 64,
simulate_streaming: bool = False,
@@ -209,7 +212,7 @@ class Conformer(Transformer):
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
Note: states will be modified in this function.
chunk_size:
The chunk size for decoding; this will be used to simulate streaming
decoding using masking.
@@ -239,10 +242,6 @@ class Conformer(Transformer):
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert (
states is not None
), "Require cache when sending data in streaming mode"
assert (
len(states) == 2
and states[0].shape
@@ -266,17 +265,14 @@ class Conformer(Transformer):
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder(
x = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
states=states,
left_context=left_context,
) # (T, B, F)
else:
assert states is None
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
@@ -389,8 +385,6 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
@@ -400,19 +394,95 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src, _ = self.conv_module(src)
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
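The block above repeats the pre-/post-norm ("normalize_before") residual pattern for each sub-module; a compressed sketch of one such sub-block, with toy dimensions:

import torch
import torch.nn as nn


class FFBlock(nn.Module):
    def __init__(self, dim: int = 8, normalize_before: bool = True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)
        self.normalize_before = normalize_before

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if self.normalize_before:  # pre-norm: normalize the branch input
            x = self.norm(x)
        x = residual + self.dropout(self.ff(x))
        if not self.normalize_before:  # post-norm: normalize the sum
            x = self.norm(x)
        return x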
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(left_context, batch, attention_dim); the second is
the conv_cache, with shape
(cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
@@ -433,15 +503,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_mha(src)
key = src
val = src
if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
else:
assert left_context == 0
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
src_att = self.self_attn(
src,
@@ -461,11 +525,8 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_conv(src)
if not self.training and states is not None:
src, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache
else:
src = self.conv_module(src)
src, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache
src = residual + self.dropout(src)
if not self.normalize_before:
@@ -513,8 +574,6 @@ class ConformerEncoder(nn.Module):
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@@ -523,21 +582,11 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(encoder_layers, left_context, batch, attention_dim); the second is
the conv_cache, with shape
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
pos_emb: (N, 2*S-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@@ -545,29 +594,65 @@ class ConformerEncoder(nn.Module):
"""
output = src
if self.training:
assert left_context == 0
assert states is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers):
cache = (
None
if states is None
else [states[0][layer_index], states[1][layer_index]]
)
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames, which contain the cached data.
It has two elements: the first is the attn_cache, with shape
(encoder_layers, left_context, batch, attention_dim); the second is
the conv_cache, with shape
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding.
This is used only in real streaming decoding; in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output,
pos_emb,
states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
left_context=left_context,
)
if states is not None:
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output
@@ -1186,7 +1271,7 @@ class ConvolutionModule(nn.Module):
def forward(
self, x: Tensor, cache: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
@@ -1227,9 +1312,10 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return (
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
)
if cache is None:
cache = torch.empty(0)
return x.permute(2, 0, 1), cache
class Swish(torch.nn.Module):