mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Add torch.jit.export
This commit is contained in:
parent
605838da55
commit
0325e3a04e
@ -375,9 +375,10 @@ def decode_one_batch(
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
states=[],
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
|
@ -109,6 +109,47 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
Note: not needed here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
Note: not needed for here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""How many left context can be seen in chunks when calculating attention.
|
||||
Note: not needed here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -130,6 +171,7 @@ def main():
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
@ -288,7 +288,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from encoder_interface import EncoderInterface
|
||||
@ -157,7 +157,6 @@ class Conformer(EncoderInterface):
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
mask = None
|
||||
|
||||
if self.dynamic_chunk_training:
|
||||
assert (
|
||||
@ -176,24 +175,32 @@ class Conformer(EncoderInterface):
|
||||
num_left_chunks=self.num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
x, _ = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
else:
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=None,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
return x, lengths
|
||||
|
||||
@torch.jit.export
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
states: List[Tensor],
|
||||
warmup: float = 1.0,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
chunk_size: int = 16,
|
||||
left_context: int = 64,
|
||||
simulate_streaming: bool = False,
|
||||
@ -205,17 +212,17 @@ class Conformer(EncoderInterface):
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
Note: states will be modified in this function.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
chunk_size:
|
||||
The chunk size for decoding, this will be used to simulate streaming
|
||||
decoding using masking.
|
||||
@ -245,10 +252,6 @@ class Conformer(EncoderInterface):
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
if not simulate_streaming:
|
||||
assert (
|
||||
states is not None
|
||||
), "Require cache when sending data in streaming mode"
|
||||
|
||||
assert (
|
||||
len(states) == 2
|
||||
and states[0].shape
|
||||
@ -272,7 +275,7 @@ class Conformer(EncoderInterface):
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
|
||||
x = self.encoder(
|
||||
x = self.encoder.chunk_forward(
|
||||
embed,
|
||||
pos_enc,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
@ -282,7 +285,6 @@ class Conformer(EncoderInterface):
|
||||
) # (T, B, F)
|
||||
|
||||
else:
|
||||
assert states is None
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
x = self.encoder_embed(x)
|
||||
@ -392,9 +394,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor]:
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
@ -405,20 +405,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
@ -440,15 +429,82 @@ class ConformerEncoderLayer(nn.Module):
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
key = src
|
||||
val = src
|
||||
if not self.training and states is not None:
|
||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
states[0] = key[-left_context:, ...]
|
||||
else:
|
||||
assert left_context == 0
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
conv, _ = self.conv_module(src)
|
||||
src = src + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (left_context, src.size(1), src.size(2))
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
states[0] = key[-left_context:, ...]
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
@ -464,11 +520,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
if not self.training and states is not None:
|
||||
conv, conv_cache = self.conv_module(src, states[1])
|
||||
states[1] = conv_cache
|
||||
else:
|
||||
conv = self.conv_module(src)
|
||||
conv, conv_cache = self.conv_module(src, states[1])
|
||||
states[1] = conv_cache
|
||||
|
||||
src = src + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
@ -476,9 +530,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
|
||||
@ -511,8 +562,6 @@ class ConformerEncoder(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -523,20 +572,10 @@ class ConformerEncoder(nn.Module):
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
@ -544,30 +583,81 @@ class ConformerEncoder(nn.Module):
|
||||
"""
|
||||
output = src
|
||||
|
||||
if self.training:
|
||||
assert left_context == 0
|
||||
assert states is None
|
||||
else:
|
||||
assert left_context >= 0
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = (
|
||||
None
|
||||
if states is None
|
||||
else [states[0][layer_index], states[1][layer_index]]
|
||||
)
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (
|
||||
len(self.layers),
|
||||
left_context,
|
||||
src.size(1),
|
||||
src.size(2),
|
||||
)
|
||||
assert states[1].size(0) == len(self.layers)
|
||||
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = [states[0][layer_index], states[1][layer_index]]
|
||||
output = mod.chunk_forward(
|
||||
output,
|
||||
pos_emb,
|
||||
states=cache,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
left_context=left_context,
|
||||
)
|
||||
if states is not None:
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
|
||||
return output
|
||||
|
||||
@ -1216,7 +1306,7 @@ class ConvolutionModule(nn.Module):
|
||||
self,
|
||||
x: Tensor,
|
||||
cache: Optional[Tensor] = None,
|
||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
@ -1260,9 +1350,11 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return (
|
||||
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
||||
)
|
||||
# torch.jit.script requires return types be the same as annotated above
|
||||
if cache is None:
|
||||
cache = torch.empty(0)
|
||||
|
||||
return x.permute(2, 0, 1), cache
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
|
@ -335,9 +335,10 @@ def decode_one_batch(
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
states=[],
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
|
@ -124,6 +124,47 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
Note: not needed here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
Note: not needed for here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""How many left context can be seen in chunks when calculating attention.
|
||||
Note: not needed here, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@ -155,7 +155,6 @@ class Conformer(Transformer):
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
mask = None
|
||||
|
||||
if self.dynamic_chunk_training:
|
||||
assert (
|
||||
@ -174,10 +173,13 @@ class Conformer(Transformer):
|
||||
num_left_chunks=self.num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
x, _ = self.encoder(
|
||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, N, C)
|
||||
x = self.encoder(
|
||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, N, C)
|
||||
else:
|
||||
x = self.encoder(
|
||||
x, pos_emb, mask=None, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, N, C)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
@ -187,11 +189,12 @@ class Conformer(Transformer):
|
||||
|
||||
return logits, lengths
|
||||
|
||||
@torch.jit.export
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
states: Optional[List[torch.Tensor]] = None,
|
||||
states: List[torch.Tensor],
|
||||
chunk_size: int = 16,
|
||||
left_context: int = 64,
|
||||
simulate_streaming: bool = False,
|
||||
@ -209,7 +212,7 @@ class Conformer(Transformer):
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
Note: states will be modified in this function.
|
||||
chunk_size:
|
||||
The chunk size for decoding, this will be used to simulate streaming
|
||||
decoding using masking.
|
||||
@ -239,10 +242,6 @@ class Conformer(Transformer):
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
if not simulate_streaming:
|
||||
assert (
|
||||
states is not None
|
||||
), "Require cache when sending data in streaming mode"
|
||||
|
||||
assert (
|
||||
len(states) == 2
|
||||
and states[0].shape
|
||||
@ -266,17 +265,14 @@ class Conformer(Transformer):
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
|
||||
x = self.encoder(
|
||||
x = self.encoder.chunk_forward(
|
||||
embed,
|
||||
pos_enc,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
states=states,
|
||||
left_context=left_context,
|
||||
) # (T, B, F)
|
||||
|
||||
else:
|
||||
assert states is None
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
@ -389,8 +385,6 @@ class ConformerEncoderLayer(nn.Module):
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
@ -400,19 +394,95 @@ class ConformerEncoderLayer(nn.Module):
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
# macaron style feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = residual + self.dropout(src_att)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
# convolution module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
src, _ = self.conv_module(src)
|
||||
src = residual + self.dropout(src)
|
||||
|
||||
if not self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
# feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
|
||||
if self.normalize_before:
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
Note: states will be modified in this function.
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
@ -433,15 +503,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
key = src
|
||||
val = src
|
||||
if not self.training and states is not None:
|
||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
states[0] = key[-left_context:, ...]
|
||||
else:
|
||||
assert left_context == 0
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
states[0] = key[-left_context:, ...]
|
||||
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
@ -461,11 +525,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
if not self.training and states is not None:
|
||||
src, conv_cache = self.conv_module(src, states[1])
|
||||
states[1] = conv_cache
|
||||
else:
|
||||
src = self.conv_module(src)
|
||||
src, conv_cache = self.conv_module(src, states[1])
|
||||
states[1] = conv_cache
|
||||
src = residual + self.dropout(src)
|
||||
|
||||
if not self.normalize_before:
|
||||
@ -513,8 +574,6 @@ class ConformerEncoder(nn.Module):
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -523,21 +582,11 @@ class ConformerEncoder(nn.Module):
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: If not None, states will be modified in this function.
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Shape:
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
pos_emb: (N, 2*S-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
@ -545,29 +594,65 @@ class ConformerEncoder(nn.Module):
|
||||
"""
|
||||
output = src
|
||||
|
||||
if self.training:
|
||||
assert left_context == 0
|
||||
assert states is None
|
||||
else:
|
||||
assert left_context >= 0
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = (
|
||||
None
|
||||
if states is None
|
||||
else [states[0][layer_index], states[1][layer_index]]
|
||||
)
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
return output
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
left_context: left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
assert not self.training
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = [states[0][layer_index], states[1][layer_index]]
|
||||
output = mod.chunk_forward(
|
||||
output,
|
||||
pos_emb,
|
||||
states=cache,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
left_context=left_context,
|
||||
)
|
||||
if states is not None:
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
|
||||
return output
|
||||
|
||||
@ -1186,7 +1271,7 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, cache: Optional[Tensor] = None
|
||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
@ -1227,9 +1312,10 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return (
|
||||
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
||||
)
|
||||
if cache is None:
|
||||
cache = torch.empty(0)
|
||||
|
||||
return x.permute(2, 0, 1), cache
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user