Add torch.jit.export

This commit is contained in:
pkufool 2022-05-29 16:25:01 +08:00
parent 605838da55
commit 0325e3a04e
7 changed files with 415 additions and 153 deletions

View File

@ -375,9 +375,10 @@ def decode_one_batch(
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming: if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size, chunk_size=params.right_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,

View File

@ -109,6 +109,47 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
return parser return parser
@ -130,6 +171,7 @@ def main():
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)

View File

@ -288,7 +288,6 @@ def get_parser():
""", """,
) )
return parser return parser

View File

@ -18,7 +18,7 @@
import copy import copy
import math import math
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple
import torch import torch
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
@ -157,7 +157,6 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
mask = None
if self.dynamic_chunk_training: if self.dynamic_chunk_training:
assert ( assert (
@ -176,24 +175,32 @@ class Conformer(EncoderInterface):
num_left_chunks=self.num_left_chunks, num_left_chunks=self.num_left_chunks,
device=x.device, device=x.device,
) )
x = self.encoder(
x, _ = self.encoder(
x, x,
pos_emb, pos_emb,
mask=mask, mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
warmup=warmup, warmup=warmup,
) # (T, N, C) ) # (T, N, C)
else:
x = self.encoder(
x,
pos_emb,
mask=None,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths return x, lengths
@torch.jit.export
def streaming_forward( def streaming_forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
states: List[Tensor],
warmup: float = 1.0, warmup: float = 1.0,
states: Optional[List[Tensor]] = None,
chunk_size: int = 16, chunk_size: int = 16,
left_context: int = 64, left_context: int = 64,
simulate_streaming: bool = False, simulate_streaming: bool = False,
@ -205,17 +212,17 @@ class Conformer(EncoderInterface):
x_lens: x_lens:
A tensor of shape (batch_size,) containing the number of frames in A tensor of shape (batch_size,) containing the number of frames in
`x` before padding. `x` before padding.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
states: states:
The decode states for previous frames which contains the cached data. The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim), a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim). (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function. Note: states will be modified in this function.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
chunk_size: chunk_size:
The chunk size for decoding, this will be used to simulate streaming The chunk size for decoding, this will be used to simulate streaming
decoding using masking. decoding using masking.
@ -245,10 +252,6 @@ class Conformer(EncoderInterface):
lengths = (((x_lens - 1) >> 1) - 1) >> 1 lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming: if not simulate_streaming:
assert (
states is not None
), "Require cache when sending data in streaming mode"
assert ( assert (
len(states) == 2 len(states) == 2
and states[0].shape and states[0].shape
@ -272,7 +275,7 @@ class Conformer(EncoderInterface):
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder( x = self.encoder.chunk_forward(
embed, embed,
pos_enc, pos_enc,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
@ -282,7 +285,6 @@ class Conformer(EncoderInterface):
) # (T, B, F) ) # (T, B, F)
else: else:
assert states is None
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x) x = self.encoder_embed(x)
@ -392,9 +394,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0, warmup: float = 1.0,
states: Optional[List[Tensor]] = None, ) -> Tensor:
left_context: int = 0,
) -> Tuple[Tensor]:
""" """
Pass the input through the encoder layer. Pass the input through the encoder layer.
@ -405,20 +405,9 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently. bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*S-1, E)
src_mask: (S, S). src_mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number S is the source sequence length, N is the batch size, E is the feature number
@ -440,15 +429,82 @@ class ConformerEncoderLayer(nn.Module):
# macaron style feed forward module # macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src)) src = src + self.dropout(self.feed_forward_macaron(src))
key = src # multi-headed self-attention module
val = src src_att = self.self_attn(
if not self.training and states is not None: src,
# src: [chunk_size, N, F] e.g. [8, 41, 512] src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# convolution module
conv, _ = self.conv_module(src)
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (left_context, src.size(1), src.size(2))
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
key = torch.cat([states[0], src], dim=0) key = torch.cat([states[0], src], dim=0)
val = key val = key
states[0] = key[-left_context:, ...] states[0] = key[-left_context:, ...]
else:
assert left_context == 0
# multi-headed self-attention module # multi-headed self-attention module
src_att = self.self_attn( src_att = self.self_attn(
@ -464,11 +520,9 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att) src = src + self.dropout(src_att)
# convolution module # convolution module
if not self.training and states is not None:
conv, conv_cache = self.conv_module(src, states[1]) conv, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache states[1] = conv_cache
else:
conv = self.conv_module(src)
src = src + self.dropout(conv) src = src + self.dropout(conv)
# feed forward module # feed forward module
@ -476,9 +530,6 @@ class ConformerEncoderLayer(nn.Module):
src = self.norm_final(self.balancer(src)) src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src return src
@ -511,8 +562,6 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0, warmup: float = 1.0,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor: ) -> Tensor:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
@ -523,20 +572,10 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently. bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*S-1, E)
mask: (S, S). mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@ -544,28 +583,79 @@ class ConformerEncoder(nn.Module):
""" """
output = src output = src
if self.training:
assert left_context == 0
assert states is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers): for layer_index, mod in enumerate(self.layers):
cache = (
None
if states is None
else [states[0][layer_index], states[1][layer_index]]
)
output = mod( output = mod(
output, output,
pos_emb, pos_emb,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
warmup=warmup, warmup=warmup,
)
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (
len(self.layers),
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == len(self.layers)
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output,
pos_emb,
states=cache, states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context, left_context=left_context,
) )
if states is not None:
states[0][layer_index] = cache[0] states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1] states[1][layer_index] = cache[1]
@ -1216,7 +1306,7 @@ class ConvolutionModule(nn.Module):
self, self,
x: Tensor, x: Tensor,
cache: Optional[Tensor] = None, cache: Optional[Tensor] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) -> Tuple[Tensor, Tensor]:
"""Compute convolution module. """Compute convolution module.
Args: Args:
@ -1260,9 +1350,11 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)
return ( # torch.jit.script requires return types be the same as annotated above
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) if cache is None:
) cache = torch.empty(0)
return x.permute(2, 0, 1), cache
class Conv2dSubsampling(nn.Module): class Conv2dSubsampling(nn.Module):

View File

@ -335,9 +335,10 @@ def decode_one_batch(
) )
if params.simulate_streaming: if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size, chunk_size=params.right_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,

View File

@ -124,6 +124,47 @@ def get_parser():
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
return parser return parser

View File

@ -18,7 +18,7 @@
import copy import copy
import math import math
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -155,7 +155,6 @@ class Conformer(Transformer):
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
mask = None
if self.dynamic_chunk_training: if self.dynamic_chunk_training:
assert ( assert (
@ -174,10 +173,13 @@ class Conformer(Transformer):
num_left_chunks=self.num_left_chunks, num_left_chunks=self.num_left_chunks,
device=x.device, device=x.device,
) )
x = self.encoder(
x, _ = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, N, C) ) # (T, N, C)
else:
x = self.encoder(
x, pos_emb, mask=None, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
if self.normalize_before: if self.normalize_before:
x = self.after_norm(x) x = self.after_norm(x)
@ -187,11 +189,12 @@ class Conformer(Transformer):
return logits, lengths return logits, lengths
@torch.jit.export
def streaming_forward( def streaming_forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
states: Optional[List[torch.Tensor]] = None, states: List[torch.Tensor],
chunk_size: int = 16, chunk_size: int = 16,
left_context: int = 64, left_context: int = 64,
simulate_streaming: bool = False, simulate_streaming: bool = False,
@ -209,7 +212,7 @@ class Conformer(Transformer):
a shape of (encoder_layers, left_context, batch, attention_dim), a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim). (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function. Note: states will be modified in this function.
chunk_size: chunk_size:
The chunk size for decoding, this will be used to simulate streaming The chunk size for decoding, this will be used to simulate streaming
decoding using masking. decoding using masking.
@ -239,10 +242,6 @@ class Conformer(Transformer):
lengths = (((x_lens - 1) >> 1) - 1) >> 1 lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming: if not simulate_streaming:
assert (
states is not None
), "Require cache when sending data in streaming mode"
assert ( assert (
len(states) == 2 len(states) == 2
and states[0].shape and states[0].shape
@ -266,17 +265,14 @@ class Conformer(Transformer):
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder( x = self.encoder.chunk_forward(
embed, embed,
pos_enc, pos_enc,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
states=states, states=states,
left_context=left_context, left_context=left_context,
) # (T, B, F) ) # (T, B, F)
else: else:
assert states is None
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
@ -389,8 +385,6 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Tensor, pos_emb: Tensor,
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor: ) -> Tensor:
""" """
Pass the input through the encoder layer. Pass the input through the encoder layer.
@ -400,19 +394,95 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional). src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
src_att = self.self_attn(
src,
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src, _ = self.conv_module(src)
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states: states:
The decode states for previous frames which contains the cached data. The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim), a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim). (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function. Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding. left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S). src_mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number S is the source sequence length, N is the batch size, E is the feature number
@ -433,15 +503,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
src = self.norm_mha(src) src = self.norm_mha(src)
key = src
val = src
if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0], src], dim=0) key = torch.cat([states[0], src], dim=0)
val = key val = key
states[0] = key[-left_context:, ...] states[0] = key[-left_context:, ...]
else:
assert left_context == 0
src_att = self.self_attn( src_att = self.self_attn(
src, src,
@ -461,11 +525,8 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
src = self.norm_conv(src) src = self.norm_conv(src)
if not self.training and states is not None:
src, conv_cache = self.conv_module(src, states[1]) src, conv_cache = self.conv_module(src, states[1])
states[1] = conv_cache states[1] = conv_cache
else:
src = self.conv_module(src)
src = residual + self.dropout(src) src = residual + self.dropout(src)
if not self.normalize_before: if not self.normalize_before:
@ -513,8 +574,6 @@ class ConformerEncoder(nn.Module):
pos_emb: Tensor, pos_emb: Tensor,
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
states: Optional[List[Tensor]] = None,
left_context: int = 0,
) -> Tensor: ) -> Tensor:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
@ -523,21 +582,11 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional). mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: If not None, states will be modified in this function.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*S-1, E).
mask: (S, S). mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@ -545,27 +594,63 @@ class ConformerEncoder(nn.Module):
""" """
output = src output = src
if self.training:
assert left_context == 0
assert states is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers): for layer_index, mod in enumerate(self.layers):
cache = (
None
if states is None
else [states[0][layer_index], states[1][layer_index]]
)
output = mod( output = mod(
output, output,
pos_emb, pos_emb,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
)
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output = mod.chunk_forward(
output,
pos_emb,
states=cache, states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
left_context=left_context, left_context=left_context,
) )
if states is not None:
states[0][layer_index] = cache[0] states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1] states[1][layer_index] = cache[1]
@ -1186,7 +1271,7 @@ class ConvolutionModule(nn.Module):
def forward( def forward(
self, x: Tensor, cache: Optional[Tensor] = None self, x: Tensor, cache: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) -> Tuple[Tensor, Tensor]:
"""Compute convolution module. """Compute convolution module.
Args: Args:
@ -1227,9 +1312,10 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)
return ( if cache is None:
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache) cache = torch.empty(0)
)
return x.permute(2, 0, 1), cache
class Swish(torch.nn.Module): class Swish(torch.nn.Module):