mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
streaming conformer pruned transducer stateless
This commit is contained in:
parent
e03d237f9a
commit
ee359f4d13
@ -0,0 +1 @@
|
|||||||
|
../pruned_transducer_stateless/beam_search.py
|
@ -15,7 +15,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@ -24,7 +25,7 @@ import torch
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from transformer import Transformer
|
from transformer import Transformer
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
@ -56,6 +57,12 @@ class Conformer(Transformer):
|
|||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
vgg_frontend: bool = False,
|
vgg_frontend: bool = False,
|
||||||
|
dynamic_chunk_training: bool = True,
|
||||||
|
short_chunk_threshold: float = 0.75,
|
||||||
|
causal: bool = True,
|
||||||
|
short_chunk_size: int = 25,
|
||||||
|
use_codebook_loss: bool = False,
|
||||||
|
num_codebooks: int = 4,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__(
|
super(Conformer, self).__init__(
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
@ -69,6 +76,9 @@ class Conformer(Transformer):
|
|||||||
normalize_before=normalize_before,
|
normalize_before=normalize_before,
|
||||||
vgg_frontend=vgg_frontend,
|
vgg_frontend=vgg_frontend,
|
||||||
)
|
)
|
||||||
|
self.dynamic_chunk_training = dynamic_chunk_training
|
||||||
|
self.short_chunk_threshold = short_chunk_threshold
|
||||||
|
self.short_chunk_size = short_chunk_size
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
@ -79,6 +89,7 @@ class Conformer(Transformer):
|
|||||||
dropout,
|
dropout,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel,
|
||||||
normalize_before,
|
normalize_before,
|
||||||
|
causal,
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
@ -112,9 +123,176 @@ class Conformer(Transformer):
|
|||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
|
if self.dynamic_chunk_training:
|
||||||
|
max_len = x.size(0)
|
||||||
|
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||||
|
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||||
|
chunk_size = max_len
|
||||||
|
else:
|
||||||
|
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||||
|
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||||
|
)
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, B, F)
|
||||||
|
else:
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.after_norm(x)
|
||||||
|
|
||||||
|
logits = self.encoder_output_layer(x)
|
||||||
|
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
return logits, lengths
|
||||||
|
|
||||||
|
def streaming_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
chunk_size: int = 16,
|
||||||
|
simulate_streaming: bool = True,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# x: [N, T, C]
|
||||||
|
|
||||||
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
|
if chunk_size < 0:
|
||||||
|
# Deocding with full-right context.
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, B, F)
|
||||||
|
else:
|
||||||
|
# As temporarily in icefall only subsampling_rate == 4 is supported,
|
||||||
|
# following parameters are hard-coded here.
|
||||||
|
# Change it accordingly if other subsamling_rate are supported.
|
||||||
|
# The first frame encoder_out needs at least 7 frames fbank feature
|
||||||
|
embed_left_context = 7
|
||||||
|
# Each successive frame needs 4 cached frames fbank feature
|
||||||
|
subsampling_rate = 4
|
||||||
|
# So the extra frames needed to generate the first frame encoder_out:
|
||||||
|
embed_conv_context = embed_left_context - subsampling_rate
|
||||||
|
|
||||||
|
stride = chunk_size * subsampling_rate
|
||||||
|
decoding_window = embed_conv_context + stride
|
||||||
|
if simulate_streaming:
|
||||||
|
# simulate chunk_by_chunk streaming decoding
|
||||||
|
# Results of this branch should be identical to following
|
||||||
|
# "else" branch.
|
||||||
|
# But this branch is a little slower
|
||||||
|
# as the feature is feeded chunk by chunk
|
||||||
|
|
||||||
|
# store the result of chunk_by_chunk decoding
|
||||||
|
encoder_output = []
|
||||||
|
|
||||||
|
# caches
|
||||||
|
pos_emb_positive = []
|
||||||
|
pos_emb_negative = []
|
||||||
|
pos_emb_central = None
|
||||||
|
encoder_cache = [None for i in range(len(self.encoder.layers))]
|
||||||
|
conv_cache = [None for i in range(len(self.encoder.layers))]
|
||||||
|
|
||||||
|
# start chunk_by_chunk decoding
|
||||||
|
offset = 0
|
||||||
|
feature = x
|
||||||
|
num_frames = feature.size(1)
|
||||||
|
for cur in range(
|
||||||
|
0, num_frames - embed_left_context + 1, stride
|
||||||
|
):
|
||||||
|
end = min(cur + decoding_window, num_frames)
|
||||||
|
cur_feature = feature[:, cur:end, :]
|
||||||
|
cur_feature = self.encoder_embed(cur_feature)
|
||||||
|
cur_embed, cur_pos_emb = self.encoder_pos(
|
||||||
|
cur_feature, offset
|
||||||
|
)
|
||||||
|
cur_embed = cur_embed.permute(
|
||||||
|
1, 0, 2
|
||||||
|
) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
|
cur_T = cur_feature.size(1)
|
||||||
|
if cur == 0:
|
||||||
|
real_chunk_size = min(cur_T, chunk_size)
|
||||||
|
assert (
|
||||||
|
cur_pos_emb.size(1) == 2 * real_chunk_size - 1
|
||||||
|
), f"{cur_pos_emb.size(1)} == 2 * {real_chunk_size} - 1"
|
||||||
|
|
||||||
|
# Extract the central pos embedding during first chunk
|
||||||
|
pos_emb_central = cur_pos_emb[
|
||||||
|
0, (real_chunk_size - 1), :
|
||||||
|
].view(1, 1, -1)
|
||||||
|
cur_T -= 1
|
||||||
|
|
||||||
|
# first chunk with chunk_size > 1
|
||||||
|
# or not first chunk
|
||||||
|
if (cur_T > 1 and cur == 0) or cur != 0:
|
||||||
|
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
|
||||||
|
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
|
||||||
|
|
||||||
|
assert pos_emb_positive[-1].size(0) == cur_T
|
||||||
|
|
||||||
|
pos_emb_pos = torch.cat(
|
||||||
|
pos_emb_positive, dim=0
|
||||||
|
).unsqueeze(0)
|
||||||
|
pos_emb_neg = torch.cat(
|
||||||
|
pos_emb_negative, dim=0
|
||||||
|
).unsqueeze(0)
|
||||||
|
cur_pos_emb = torch.cat(
|
||||||
|
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.encoder.chunk_forward(
|
||||||
|
cur_embed,
|
||||||
|
cur_pos_emb,
|
||||||
|
src_key_padding_mask=src_key_padding_mask[
|
||||||
|
:, : offset + cur_embed.size(0)
|
||||||
|
],
|
||||||
|
encoder_cache=encoder_cache,
|
||||||
|
conv_cache=conv_cache,
|
||||||
|
offset=offset,
|
||||||
|
) # (T, B, F)
|
||||||
|
encoder_output.append(x)
|
||||||
|
offset += cur_embed.size(0)
|
||||||
|
|
||||||
|
assert num_frames - end <= 3
|
||||||
|
if num_frames != end:
|
||||||
|
logging.info(
|
||||||
|
f"The tailing {num_frames - end} frames fbank are not deocded."
|
||||||
|
)
|
||||||
|
x = torch.cat(encoder_output, dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# NOT simulate chunk_by_chunk decoding
|
||||||
|
# Results of this branch should be identical to previous
|
||||||
|
# simulate chunk_by_chunk decoding branch.
|
||||||
|
# But this branch is faster.
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
assert x.size(0) == lengths.max().item()
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||||
|
)
|
||||||
|
x = self.encoder(
|
||||||
|
x,
|
||||||
|
pos_emb,
|
||||||
|
mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
@ -153,6 +331,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
|
causal: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
self.self_attn = RelPositionMultiheadAttention(
|
self.self_attn = RelPositionMultiheadAttention(
|
||||||
@ -173,7 +352,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(
|
||||||
|
d_model, cnn_module_kernel, causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
self.norm_ff_macaron = nn.LayerNorm(
|
self.norm_ff_macaron = nn.LayerNorm(
|
||||||
d_model
|
d_model
|
||||||
@ -263,13 +444,105 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
encoder_cache: Optional[Tensor] = None,
|
||||||
|
conv_cache: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
class ConformerEncoder(nn.Module):
|
Args:
|
||||||
|
src: the sequence to the encoder layer (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
src_mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
|
"""
|
||||||
|
|
||||||
|
# macaron style feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(
|
||||||
|
self.feed_forward_macaron(src)
|
||||||
|
)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
|
||||||
|
# multi-headed self-attention module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
if encoder_cache is None:
|
||||||
|
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
|
key = src
|
||||||
|
val = key
|
||||||
|
encoder_cache = key
|
||||||
|
else:
|
||||||
|
key = torch.cat([encoder_cache, src], dim=0)
|
||||||
|
val = key
|
||||||
|
encoder_cache = key
|
||||||
|
src_att = self.self_attn(
|
||||||
|
src,
|
||||||
|
key,
|
||||||
|
val,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
offset=offset,
|
||||||
|
)[0]
|
||||||
|
src = residual + self.dropout(src_att)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
|
# convolution module
|
||||||
|
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
if conv_cache is not None:
|
||||||
|
src = torch.cat([conv_cache, src], dim=0)
|
||||||
|
conv_cache = src
|
||||||
|
|
||||||
|
src = self.conv_module(src)
|
||||||
|
src = src[-residual.size(0) :, :, :] # noqa: E203
|
||||||
|
|
||||||
|
src = residual + self.dropout(src)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
|
# feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_final(src)
|
||||||
|
|
||||||
|
return src, encoder_cache, conv_cache
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerEncoder(nn.TransformerEncoder):
|
||||||
r"""ConformerEncoder is a stack of N encoder layers
|
r"""ConformerEncoder is a stack of N encoder layers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
norm: the layer normalization component (optional).
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||||
@ -279,10 +552,11 @@ class ConformerEncoder(nn.Module):
|
|||||||
>>> out = conformer_encoder(src, pos_emb)
|
>>> out = conformer_encoder(src, pos_emb)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
def __init__(
|
||||||
super().__init__()
|
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
|
||||||
self.layers = nn.ModuleList(
|
) -> None:
|
||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
super(ConformerEncoder, self).__init__(
|
||||||
|
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
|
||||||
)
|
)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
@ -319,6 +593,55 @@ class ConformerEncoder(nn.Module):
|
|||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
encoder_cache=None,
|
||||||
|
conv_cache=None,
|
||||||
|
offset=0,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
"""
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for layer_index, mod in enumerate(self.layers):
|
||||||
|
output, e_cache, c_cache = mod.chunk_forward(
|
||||||
|
output,
|
||||||
|
pos_emb,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
encoder_cache=encoder_cache[layer_index],
|
||||||
|
conv_cache=conv_cache[layer_index],
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
encoder_cache[layer_index] = e_cache
|
||||||
|
conv_cache[layer_index] = c_cache
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@ -346,12 +669,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor) -> None:
|
def extend_pe(self, x: Tensor, offset: int = 0) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
|
x_size_1 = offset + x.size(1)
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
x.device
|
x.device
|
||||||
@ -361,9 +685,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||||
# position of key vector. We use position relative positions when keys
|
# position of key vector. We use position relative positions when keys
|
||||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp(
|
||||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||||
* -(math.log(10000.0) / self.d_model)
|
* -(math.log(10000.0) / self.d_model)
|
||||||
@ -381,26 +705,35 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
def forward(
|
||||||
|
self, x: torch.Tensor, offset: int = 0
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
|
offset: time-index of the first frame of x.
|
||||||
|
used to compute positional encoding in a streaming fasion.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
self.extend_pe(x, offset)
|
||||||
x = x * self.xscale
|
x = x * self.xscale
|
||||||
|
x_size_1 = offset + x.size(1)
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x.size(1)
|
- x_size_1
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x.size(1),
|
+ x_size_1,
|
||||||
]
|
]
|
||||||
|
x_T = x.size(1)
|
||||||
|
if offset > 0:
|
||||||
|
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
|
||||||
|
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
@ -464,6 +797,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -522,9 +856,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask=key_padding_mask,
|
key_padding_mask=key_padding_mask,
|
||||||
need_weights=need_weights,
|
need_weights=need_weights,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def rel_shift(self, x: Tensor) -> Tensor:
|
def rel_shift(self, x: Tensor, offset=0) -> Tensor:
|
||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -533,18 +868,20 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: tensor of shape (batch, head, time1, time2)
|
Tensor: tensor of shape (batch, head, time1, time2)
|
||||||
(note: time2 has the same value as time1, but it is for
|
(note: time2 == time1 + offset, since it is for
|
||||||
the key, while time1 is for the query).
|
the key, while time1 is for the query).
|
||||||
"""
|
"""
|
||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
assert n == 2 * time1 - 1
|
time2 = time1 + offset
|
||||||
|
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
batch_stride = x.stride(0)
|
batch_stride = x.stride(0)
|
||||||
head_stride = x.stride(1)
|
head_stride = x.stride(1)
|
||||||
time1_stride = x.stride(2)
|
time1_stride = x.stride(2)
|
||||||
n_stride = x.stride(3)
|
n_stride = x.stride(3)
|
||||||
|
|
||||||
return x.as_strided(
|
return x.as_strided(
|
||||||
(batch_size, num_heads, time1, time1),
|
(batch_size, num_heads, time1, time2),
|
||||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
storage_offset=n_stride * (time1 - 1),
|
storage_offset=n_stride * (time1 - 1),
|
||||||
)
|
)
|
||||||
@ -566,6 +903,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -639,7 +977,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
if _b is not None:
|
if _b is not None:
|
||||||
_b = _b[_start:_end]
|
_b = _b[_start:_end]
|
||||||
q = nn.functional.linear(query, _w, _b)
|
q = nn.functional.linear(query, _w, _b)
|
||||||
|
|
||||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||||
_b = in_proj_bias
|
_b = in_proj_bias
|
||||||
_start = embed_dim
|
_start = embed_dim
|
||||||
@ -745,7 +1082,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
pos_emb_bsz = pos_emb.size(0)
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
|
||||||
|
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||||
|
p = p.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
@ -765,10 +1104,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
# compute matrix b and matrix d
|
# compute matrix b and matrix d
|
||||||
matrix_bd = torch.matmul(
|
matrix_bd = torch.matmul(
|
||||||
q_with_bias_v, p.transpose(-2, -1)
|
q_with_bias_v, p
|
||||||
) # (batch, head, time1, 2*time1-1)
|
) # (batch, head, time1, 2*time1-1)
|
||||||
matrix_bd = self.rel_shift(matrix_bd)
|
matrix_bd = self.rel_shift(
|
||||||
|
matrix_bd, offset=offset
|
||||||
|
) # [B, head, time1, time2]
|
||||||
attn_output_weights = (
|
attn_output_weights = (
|
||||||
matrix_ac + matrix_bd
|
matrix_ac + matrix_bd
|
||||||
) * scaling # (batch, head, time1, time2)
|
) * scaling # (batch, head, time1, time2)
|
||||||
@ -835,11 +1175,16 @@ class ConvolutionModule(nn.Module):
|
|||||||
channels (int): The number of channels of conv layers.
|
channels (int): The number of channels of conv layers.
|
||||||
kernel_size (int): Kernerl size of conv layers.
|
kernel_size (int): Kernerl size of conv layers.
|
||||||
bias (bool): Whether to use bias in conv layers (default=True).
|
bias (bool): Whether to use bias in conv layers (default=True).
|
||||||
|
causal (bool): Whether to use causal convlution (default=True).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, channels: int, kernel_size: int, bias: bool = True
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
bias: bool = True,
|
||||||
|
causal: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an ConvolutionModule object."""
|
"""Construct an ConvolutionModule object."""
|
||||||
super(ConvolutionModule, self).__init__()
|
super(ConvolutionModule, self).__init__()
|
||||||
@ -854,12 +1199,20 @@ class ConvolutionModule(nn.Module):
|
|||||||
padding=0,
|
padding=0,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
assert (
|
||||||
|
causal
|
||||||
|
), "Currently, causal convolution is required for streaming conformer."
|
||||||
|
|
||||||
|
# Manualy padding self.lorder zeros to the left during forward.
|
||||||
|
self.lorder = kernel_size - 1
|
||||||
|
padding = 0
|
||||||
|
|
||||||
self.depthwise_conv = nn.Conv1d(
|
self.depthwise_conv = nn.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=(kernel_size - 1) // 2,
|
padding=padding,
|
||||||
groups=channels,
|
groups=channels,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
@ -892,6 +1245,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
|
if self.lorder > 0:
|
||||||
|
# Make depthwise_conv causal by
|
||||||
|
# manualy padding self.lorder zeros to the left
|
||||||
|
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
# x is (batch, channels, time)
|
# x is (batch, channels, time)
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
@ -18,18 +18,22 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless/decode.py \
|
./streaming_pruned_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--simulate-streaming [True|False] \
|
||||||
|
--right-chunk-size [1/4/8/16/32/-1] \
|
||||||
|
--epoch 49 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./streaming_pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search
|
(2) beam search
|
||||||
./pruned_transducer_stateless/decode.py \
|
./streaming_pruned_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--simulate-streaming [True|False] \
|
||||||
|
--right-chunk-size [1/4/8/16/32/-1] \
|
||||||
|
--epoch 49 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless/exp \
|
--exp-dir ./streaming_pruned_transducer_stateless/exp \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
@ -38,6 +42,7 @@ Usage:
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
@ -52,12 +57,17 @@ from decoder import Decoder
|
|||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
save_checkpoint,
|
||||||
|
)
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,6 +77,29 @@ def get_parser():
|
|||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to split fbanks into chunks to simulate forward conformer"
|
||||||
|
"in a streaming fashion",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tailing-dummy-frames",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="tailing dummy frames padded to the right,"
|
||||||
|
"only used during decoding",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--right-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="right context to attend during decoding",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--epoch",
|
"--epoch",
|
||||||
type=int,
|
type=int,
|
||||||
@ -86,7 +119,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless/exp",
|
default="streaming_pruned_transducer_stateless/exp",
|
||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -145,6 +178,8 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
|
# model average
|
||||||
|
"save_averaged_model": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
@ -236,10 +271,26 @@ def decode_one_batch(
|
|||||||
# at entry, feature is (N, T, C)
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
|
# Extra dummy tailing frames my reduce deletion error
|
||||||
|
# example WITHOUT padding:
|
||||||
|
# CHAPTER SEVEN ON THE RACES OF MAN
|
||||||
|
# example WITH padding:
|
||||||
|
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
|
||||||
|
tailing_frames = (
|
||||||
|
torch.tensor([-23.0259])
|
||||||
|
.expand([feature.size(0), params.tailing_dummy_frames, 80])
|
||||||
|
.to(feature.device)
|
||||||
|
)
|
||||||
|
feature = torch.cat([feature, tailing_frames], dim=1)
|
||||||
|
supervisions["num_frames"] += params.tailing_dummy_frames
|
||||||
|
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
chunk_size=params.right_chunk_size,
|
||||||
|
simulate_streaming=params.simulate_streaming,
|
||||||
)
|
)
|
||||||
hyps = []
|
hyps = []
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
@ -395,6 +446,9 @@ def main():
|
|||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
params.suffix += f"-chunk_size-{params.right_chunk_size}"
|
||||||
|
params.suffix += f"-{params.simulate_streaming}"
|
||||||
|
params.suffix += f"-tailing-dummy-frams-{params.tailing_dummy_frames}"
|
||||||
if params.decoding_method == "beam_search":
|
if params.decoding_method == "beam_search":
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
else:
|
else:
|
||||||
@ -425,15 +479,24 @@ def main():
|
|||||||
if params.avg == 1:
|
if params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
model_path = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # noqa: E501
|
||||||
filenames = []
|
if os.path.isfile(model_path):
|
||||||
for i in range(start, params.epoch + 1):
|
load_checkpoint(model_path, model)
|
||||||
if start >= 0:
|
else:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
start = params.epoch - params.avg + 1
|
||||||
logging.info(f"averaging {filenames}")
|
filenames = []
|
||||||
model.to(device)
|
for i in range(start, params.epoch + 1):
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
if params.save_averaged_model:
|
||||||
|
save_checkpoint(
|
||||||
|
filename=model_path,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
@ -21,11 +21,13 @@ Usage:
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless/train.py \
|
./streaming_pruned_transducer_stateless/train.py \
|
||||||
|
--short-chunk-size=25 \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--full-libri 1 \
|
||||||
|
--num-epochs 50 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir pruned_transducer_stateless/exp \
|
--exp-dir streaming_pruned_transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
"""
|
"""
|
||||||
@ -74,6 +76,12 @@ def get_parser():
|
|||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="chunk length of dynamic training",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--world-size",
|
"--world-size",
|
||||||
@ -252,6 +260,8 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
|
"dynamic_chunk_training": True,
|
||||||
|
"causal": True, # Now only causal convolution is verified
|
||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
@ -274,6 +284,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
|
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||||
|
short_chunk_size=params.short_chunk_size,
|
||||||
|
causal=params.causal,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
@ -694,6 +694,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
|||||||
return expaned_lengths >= lengths.unsqueeze(1)
|
return expaned_lengths >= lengths.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
# From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
|
||||||
|
def subsequent_chunk_mask(
|
||||||
|
size: int,
|
||||||
|
chunk_size: int,
|
||||||
|
num_left_chunks: int = -1,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||||
|
this is for streaming encoder
|
||||||
|
Args:
|
||||||
|
size (int): size of mask
|
||||||
|
chunk_size (int): size of chunk
|
||||||
|
num_left_chunks (int): number of left chunks
|
||||||
|
<0: use full chunk
|
||||||
|
>=0: use num_left_chunks
|
||||||
|
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: mask
|
||||||
|
Examples:
|
||||||
|
>>> subsequent_chunk_mask(4, 2)
|
||||||
|
[[1, 1, 0, 0],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1]]
|
||||||
|
"""
|
||||||
|
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||||
|
for i in range(size):
|
||||||
|
if num_left_chunks < 0:
|
||||||
|
start = 0
|
||||||
|
else:
|
||||||
|
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||||
|
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||||
|
ret[i, start:ending] = True
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def l1_norm(x):
|
def l1_norm(x):
|
||||||
return torch.sum(torch.abs(x))
|
return torch.sum(torch.abs(x))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user