mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-09 09:04:19 +00:00
Add torch.jit.export
This commit is contained in:
parent
605838da55
commit
0325e3a04e
@ -375,9 +375,10 @@ def decode_one_batch(
|
|||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
states=[],
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.right_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True,
|
simulate_streaming=True,
|
||||||
|
@ -109,6 +109,47 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
Note: not needed for here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""How many left context can be seen in chunks when calculating attention.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -130,6 +171,7 @@ def main():
|
|||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
@ -288,7 +288,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
@ -157,7 +157,6 @@ class Conformer(EncoderInterface):
|
|||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
mask = None
|
|
||||||
|
|
||||||
if self.dynamic_chunk_training:
|
if self.dynamic_chunk_training:
|
||||||
assert (
|
assert (
|
||||||
@ -176,24 +175,32 @@ class Conformer(EncoderInterface):
|
|||||||
num_left_chunks=self.num_left_chunks,
|
num_left_chunks=self.num_left_chunks,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
)
|
)
|
||||||
|
x = self.encoder(
|
||||||
x, _ = self.encoder(
|
x,
|
||||||
x,
|
pos_emb,
|
||||||
pos_emb,
|
mask=mask,
|
||||||
mask=mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
warmup=warmup,
|
||||||
warmup=warmup,
|
) # (T, N, C)
|
||||||
) # (T, N, C)
|
else:
|
||||||
|
x = self.encoder(
|
||||||
|
x,
|
||||||
|
pos_emb,
|
||||||
|
mask=None,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
return x, lengths
|
return x, lengths
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
|
states: List[Tensor],
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
states: Optional[List[Tensor]] = None,
|
|
||||||
chunk_size: int = 16,
|
chunk_size: int = 16,
|
||||||
left_context: int = 64,
|
left_context: int = 64,
|
||||||
simulate_streaming: bool = False,
|
simulate_streaming: bool = False,
|
||||||
@ -205,17 +212,17 @@ class Conformer(EncoderInterface):
|
|||||||
x_lens:
|
x_lens:
|
||||||
A tensor of shape (batch_size,) containing the number of frames in
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
`x` before padding.
|
`x` before padding.
|
||||||
warmup:
|
|
||||||
A floating point value that gradually increases from 0 throughout
|
|
||||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
|
||||||
to turn modules on sequentially.
|
|
||||||
states:
|
states:
|
||||||
The decode states for previous frames which contains the cached data.
|
The decode states for previous frames which contains the cached data.
|
||||||
It has two elements, the first element is the attn_cache which has
|
It has two elements, the first element is the attn_cache which has
|
||||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||||
the second element is the conv_cache which has a shape of
|
the second element is the conv_cache which has a shape of
|
||||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||||
Note: If not None, states will be modified in this function.
|
Note: states will be modified in this function.
|
||||||
|
warmup:
|
||||||
|
A floating point value that gradually increases from 0 throughout
|
||||||
|
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||||
|
to turn modules on sequentially.
|
||||||
chunk_size:
|
chunk_size:
|
||||||
The chunk size for decoding, this will be used to simulate streaming
|
The chunk size for decoding, this will be used to simulate streaming
|
||||||
decoding using masking.
|
decoding using masking.
|
||||||
@ -245,10 +252,6 @@ class Conformer(EncoderInterface):
|
|||||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
|
|
||||||
if not simulate_streaming:
|
if not simulate_streaming:
|
||||||
assert (
|
|
||||||
states is not None
|
|
||||||
), "Require cache when sending data in streaming mode"
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(states) == 2
|
len(states) == 2
|
||||||
and states[0].shape
|
and states[0].shape
|
||||||
@ -272,7 +275,7 @@ class Conformer(EncoderInterface):
|
|||||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
x = self.encoder(
|
x = self.encoder.chunk_forward(
|
||||||
embed,
|
embed,
|
||||||
pos_enc,
|
pos_enc,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
@ -282,7 +285,6 @@ class Conformer(EncoderInterface):
|
|||||||
) # (T, B, F)
|
) # (T, B, F)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert states is None
|
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
@ -392,9 +394,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src_mask: Optional[Tensor] = None,
|
src_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
states: Optional[List[Tensor]] = None,
|
) -> Tensor:
|
||||||
left_context: int = 0,
|
|
||||||
) -> Tuple[Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Pass the input through the encoder layer.
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
@ -405,20 +405,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
bypass layers more frequently.
|
bypass layers more frequently.
|
||||||
states:
|
|
||||||
The decode states for previous frames which contains the cached data.
|
|
||||||
It has two elements, the first element is the attn_cache which has
|
|
||||||
a shape of (left_context, batch, attention_dim),
|
|
||||||
the second element is the conv_cache which has a shape of
|
|
||||||
(cnn_module_kernel-1, batch, conv_dim).
|
|
||||||
Note: If not None, states will be modified in this function.
|
|
||||||
left_context: left context (in frames) used during streaming decoding.
|
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
|
||||||
it MUST be 0.
|
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
pos_emb: (N, 2*S-1, E)
|
||||||
src_mask: (S, S).
|
src_mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
@ -440,15 +429,82 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# macaron style feed forward module
|
# macaron style feed forward module
|
||||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||||
|
|
||||||
key = src
|
# multi-headed self-attention module
|
||||||
val = src
|
src_att = self.self_attn(
|
||||||
if not self.training and states is not None:
|
src,
|
||||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
src,
|
||||||
key = torch.cat([states[0], src], dim=0)
|
src,
|
||||||
val = key
|
pos_emb=pos_emb,
|
||||||
states[0] = key[-left_context:, ...]
|
attn_mask=src_mask,
|
||||||
else:
|
key_padding_mask=src_key_padding_mask,
|
||||||
assert left_context == 0
|
)[0]
|
||||||
|
|
||||||
|
src = src + self.dropout(src_att)
|
||||||
|
|
||||||
|
# convolution module
|
||||||
|
conv, _ = self.conv_module(src)
|
||||||
|
src = src + self.dropout(conv)
|
||||||
|
|
||||||
|
# feed forward module
|
||||||
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
|
|
||||||
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
|
if alpha != 1.0:
|
||||||
|
src = alpha * src + (1 - alpha) * src_orig
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
states: List[Tensor],
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder layer (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has two elements, the first element is the attn_cache which has
|
||||||
|
a shape of (left_context, batch, attention_dim),
|
||||||
|
the second element is the conv_cache which has a shape of
|
||||||
|
(cnn_module_kernel-1, batch, conv_dim).
|
||||||
|
Note: states will be modified in this function.
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
|
bypass layers more frequently.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||||
|
src_mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert not self.training
|
||||||
|
assert len(states) == 2
|
||||||
|
assert states[0].shape == (left_context, src.size(1), src.size(2))
|
||||||
|
|
||||||
|
# macaron style feed forward module
|
||||||
|
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||||
|
|
||||||
|
key = torch.cat([states[0], src], dim=0)
|
||||||
|
val = key
|
||||||
|
states[0] = key[-left_context:, ...]
|
||||||
|
|
||||||
# multi-headed self-attention module
|
# multi-headed self-attention module
|
||||||
src_att = self.self_attn(
|
src_att = self.self_attn(
|
||||||
@ -464,11 +520,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src = src + self.dropout(src_att)
|
src = src + self.dropout(src_att)
|
||||||
|
|
||||||
# convolution module
|
# convolution module
|
||||||
if not self.training and states is not None:
|
conv, conv_cache = self.conv_module(src, states[1])
|
||||||
conv, conv_cache = self.conv_module(src, states[1])
|
states[1] = conv_cache
|
||||||
states[1] = conv_cache
|
|
||||||
else:
|
|
||||||
conv = self.conv_module(src)
|
|
||||||
src = src + self.dropout(conv)
|
src = src + self.dropout(conv)
|
||||||
|
|
||||||
# feed forward module
|
# feed forward module
|
||||||
@ -476,9 +530,6 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
src = self.norm_final(self.balancer(src))
|
src = self.norm_final(self.balancer(src))
|
||||||
|
|
||||||
if alpha != 1.0:
|
|
||||||
src = alpha * src + (1 - alpha) * src_orig
|
|
||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
|
||||||
@ -511,8 +562,6 @@ class ConformerEncoder(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
states: Optional[List[Tensor]] = None,
|
|
||||||
left_context: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
@ -523,20 +572,10 @@ class ConformerEncoder(nn.Module):
|
|||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
bypass layers more frequently.
|
bypass layers more frequently.
|
||||||
states:
|
|
||||||
The decode states for previous frames which contains the cached data.
|
|
||||||
It has two elements, the first element is the attn_cache which has
|
|
||||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
|
||||||
the second element is the conv_cache which has a shape of
|
|
||||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
|
||||||
Note: If not None, states will be modified in this function.
|
|
||||||
left_context: left context (in frames) used during streaming decoding.
|
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
|
||||||
it MUST be 0.
|
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
pos_emb: (N, 2*S-1, E)
|
||||||
mask: (S, S).
|
mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
@ -544,30 +583,81 @@ class ConformerEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
if self.training:
|
|
||||||
assert left_context == 0
|
|
||||||
assert states is None
|
|
||||||
else:
|
|
||||||
assert left_context >= 0
|
|
||||||
|
|
||||||
for layer_index, mod in enumerate(self.layers):
|
for layer_index, mod in enumerate(self.layers):
|
||||||
cache = (
|
|
||||||
None
|
|
||||||
if states is None
|
|
||||||
else [states[0][layer_index], states[1][layer_index]]
|
|
||||||
)
|
|
||||||
output = mod(
|
output = mod(
|
||||||
output,
|
output,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
src_mask=mask,
|
src_mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
states: List[Tensor],
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has two elements, the first element is the attn_cache which has
|
||||||
|
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||||
|
the second element is the conv_cache which has a shape of
|
||||||
|
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||||
|
Note: states will be modified in this function.
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
|
bypass layers more frequently.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert not self.training
|
||||||
|
assert len(states) == 2
|
||||||
|
assert states[0].shape == (
|
||||||
|
len(self.layers),
|
||||||
|
left_context,
|
||||||
|
src.size(1),
|
||||||
|
src.size(2),
|
||||||
|
)
|
||||||
|
assert states[1].size(0) == len(self.layers)
|
||||||
|
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for layer_index, mod in enumerate(self.layers):
|
||||||
|
cache = [states[0][layer_index], states[1][layer_index]]
|
||||||
|
output = mod.chunk_forward(
|
||||||
|
output,
|
||||||
|
pos_emb,
|
||||||
states=cache,
|
states=cache,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
left_context=left_context,
|
left_context=left_context,
|
||||||
)
|
)
|
||||||
if states is not None:
|
states[0][layer_index] = cache[0]
|
||||||
states[0][layer_index] = cache[0]
|
states[1][layer_index] = cache[1]
|
||||||
states[1][layer_index] = cache[1]
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1216,7 +1306,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
cache: Optional[Tensor] = None,
|
cache: Optional[Tensor] = None,
|
||||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Compute convolution module.
|
"""Compute convolution module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1260,9 +1350,11 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
return (
|
# torch.jit.script requires return types be the same as annotated above
|
||||||
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
if cache is None:
|
||||||
)
|
cache = torch.empty(0)
|
||||||
|
|
||||||
|
return x.permute(2, 0, 1), cache
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(nn.Module):
|
||||||
|
@ -335,9 +335,10 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
|
states=[],
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.right_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True,
|
simulate_streaming=True,
|
||||||
|
@ -124,6 +124,47 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
Note: not needed for here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""How many left context can be seen in chunks when calculating attention.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -155,7 +155,6 @@ class Conformer(Transformer):
|
|||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
mask = None
|
|
||||||
|
|
||||||
if self.dynamic_chunk_training:
|
if self.dynamic_chunk_training:
|
||||||
assert (
|
assert (
|
||||||
@ -174,10 +173,13 @@ class Conformer(Transformer):
|
|||||||
num_left_chunks=self.num_left_chunks,
|
num_left_chunks=self.num_left_chunks,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
)
|
)
|
||||||
|
x = self.encoder(
|
||||||
x, _ = self.encoder(
|
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
) # (T, N, C)
|
||||||
) # (T, N, C)
|
else:
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, mask=None, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
@ -187,11 +189,12 @@ class Conformer(Transformer):
|
|||||||
|
|
||||||
return logits, lengths
|
return logits, lengths
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
states: Optional[List[torch.Tensor]] = None,
|
states: List[torch.Tensor],
|
||||||
chunk_size: int = 16,
|
chunk_size: int = 16,
|
||||||
left_context: int = 64,
|
left_context: int = 64,
|
||||||
simulate_streaming: bool = False,
|
simulate_streaming: bool = False,
|
||||||
@ -209,7 +212,7 @@ class Conformer(Transformer):
|
|||||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||||
the second element is the conv_cache which has a shape of
|
the second element is the conv_cache which has a shape of
|
||||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||||
Note: If not None, states will be modified in this function.
|
Note: states will be modified in this function.
|
||||||
chunk_size:
|
chunk_size:
|
||||||
The chunk size for decoding, this will be used to simulate streaming
|
The chunk size for decoding, this will be used to simulate streaming
|
||||||
decoding using masking.
|
decoding using masking.
|
||||||
@ -239,10 +242,6 @@ class Conformer(Transformer):
|
|||||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||||
|
|
||||||
if not simulate_streaming:
|
if not simulate_streaming:
|
||||||
assert (
|
|
||||||
states is not None
|
|
||||||
), "Require cache when sending data in streaming mode"
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(states) == 2
|
len(states) == 2
|
||||||
and states[0].shape
|
and states[0].shape
|
||||||
@ -266,17 +265,14 @@ class Conformer(Transformer):
|
|||||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
x = self.encoder(
|
x = self.encoder.chunk_forward(
|
||||||
embed,
|
embed,
|
||||||
pos_enc,
|
pos_enc,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
states=states,
|
states=states,
|
||||||
left_context=left_context,
|
left_context=left_context,
|
||||||
) # (T, B, F)
|
) # (T, B, F)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert states is None
|
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
@ -389,8 +385,6 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
src_mask: Optional[Tensor] = None,
|
src_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
states: Optional[List[Tensor]] = None,
|
|
||||||
left_context: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Pass the input through the encoder layer.
|
Pass the input through the encoder layer.
|
||||||
@ -400,19 +394,95 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
pos_emb: Positional embedding tensor (required).
|
pos_emb: Positional embedding tensor (required).
|
||||||
src_mask: the mask for the src sequence (optional).
|
src_mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E).
|
||||||
|
src_mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
|
"""
|
||||||
|
# macaron style feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(
|
||||||
|
self.feed_forward_macaron(src)
|
||||||
|
)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
|
||||||
|
# multi-headed self-attention module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
|
src_att = self.self_attn(
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
)[0]
|
||||||
|
src = residual + self.dropout(src_att)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
|
# convolution module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
|
src, _ = self.conv_module(src)
|
||||||
|
src = residual + self.dropout(src)
|
||||||
|
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
|
# feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_final(src)
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
states: List[Tensor],
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder layer (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
states:
|
states:
|
||||||
The decode states for previous frames which contains the cached data.
|
The decode states for previous frames which contains the cached data.
|
||||||
It has two elements, the first element is the attn_cache which has
|
It has two elements, the first element is the attn_cache which has
|
||||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||||
the second element is the conv_cache which has a shape of
|
the second element is the conv_cache which has a shape of
|
||||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||||
Note: If not None, states will be modified in this function.
|
Note: states will be modified in this function.
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
left_context: left context (in frames) used during streaming decoding.
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||||
src_mask: (S, S).
|
src_mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
@ -433,15 +503,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_mha(src)
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
key = src
|
key = torch.cat([states[0], src], dim=0)
|
||||||
val = src
|
val = key
|
||||||
if not self.training and states is not None:
|
states[0] = key[-left_context:, ...]
|
||||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
|
||||||
key = torch.cat([states[0], src], dim=0)
|
|
||||||
val = key
|
|
||||||
states[0] = key[-left_context:, ...]
|
|
||||||
else:
|
|
||||||
assert left_context == 0
|
|
||||||
|
|
||||||
src_att = self.self_attn(
|
src_att = self.self_attn(
|
||||||
src,
|
src,
|
||||||
@ -461,11 +525,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_conv(src)
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
if not self.training and states is not None:
|
src, conv_cache = self.conv_module(src, states[1])
|
||||||
src, conv_cache = self.conv_module(src, states[1])
|
states[1] = conv_cache
|
||||||
states[1] = conv_cache
|
|
||||||
else:
|
|
||||||
src = self.conv_module(src)
|
|
||||||
src = residual + self.dropout(src)
|
src = residual + self.dropout(src)
|
||||||
|
|
||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
@ -513,8 +574,6 @@ class ConformerEncoder(nn.Module):
|
|||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
states: Optional[List[Tensor]] = None,
|
|
||||||
left_context: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
@ -523,21 +582,11 @@ class ConformerEncoder(nn.Module):
|
|||||||
pos_emb: Positional embedding tensor (required).
|
pos_emb: Positional embedding tensor (required).
|
||||||
mask: the mask for the src sequence (optional).
|
mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
states:
|
|
||||||
The decode states for previous frames which contains the cached data.
|
|
||||||
It has two elements, the first element is the attn_cache which has
|
|
||||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
|
||||||
the second element is the conv_cache which has a shape of
|
|
||||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
|
||||||
Note: If not None, states will be modified in this function.
|
|
||||||
left_context: left context (in frames) used during streaming decoding.
|
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
|
||||||
it MUST be 0.
|
|
||||||
Shape:
|
Shape:
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
pos_emb: (N, 2*S-1, E).
|
||||||
mask: (S, S).
|
mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
@ -545,29 +594,65 @@ class ConformerEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
if self.training:
|
|
||||||
assert left_context == 0
|
|
||||||
assert states is None
|
|
||||||
else:
|
|
||||||
assert left_context >= 0
|
|
||||||
|
|
||||||
for layer_index, mod in enumerate(self.layers):
|
for layer_index, mod in enumerate(self.layers):
|
||||||
cache = (
|
|
||||||
None
|
|
||||||
if states is None
|
|
||||||
else [states[0][layer_index], states[1][layer_index]]
|
|
||||||
)
|
|
||||||
output = mod(
|
output = mod(
|
||||||
output,
|
output,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
src_mask=mask,
|
src_mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
states: List[Tensor],
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has two elements, the first element is the attn_cache which has
|
||||||
|
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||||
|
the second element is the conv_cache which has a shape of
|
||||||
|
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||||
|
Note: states will be modified in this function.
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert not self.training
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for layer_index, mod in enumerate(self.layers):
|
||||||
|
cache = [states[0][layer_index], states[1][layer_index]]
|
||||||
|
output = mod.chunk_forward(
|
||||||
|
output,
|
||||||
|
pos_emb,
|
||||||
states=cache,
|
states=cache,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
left_context=left_context,
|
left_context=left_context,
|
||||||
)
|
)
|
||||||
if states is not None:
|
states[0][layer_index] = cache[0]
|
||||||
states[0][layer_index] = cache[0]
|
states[1][layer_index] = cache[1]
|
||||||
states[1][layer_index] = cache[1]
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1186,7 +1271,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, cache: Optional[Tensor] = None
|
self, x: Tensor, cache: Optional[Tensor] = None
|
||||||
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Compute convolution module.
|
"""Compute convolution module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1227,9 +1312,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||||
|
|
||||||
return (
|
if cache is None:
|
||||||
x.permute(2, 0, 1) if cache is None else (x.permute(2, 0, 1), cache)
|
cache = torch.empty(0)
|
||||||
)
|
|
||||||
|
return x.permute(2, 0, 1), cache
|
||||||
|
|
||||||
|
|
||||||
class Swish(torch.nn.Module):
|
class Swish(torch.nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user