streaming conformer pruned transducer stateless

Guo Liyong 2022-03-07 16:49:46 +08:00
parent e03d237f9a
commit ee359f4d13
5 changed files with 521 additions and 51 deletions


@ -0,0 +1 @@
../pruned_transducer_stateless/beam_search.py


@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import logging
import math import math
import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
@ -24,7 +25,7 @@ import torch
from torch import Tensor, nn from torch import Tensor, nn
from transformer import Transformer from transformer import Transformer
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(Transformer): class Conformer(Transformer):
@ -56,6 +57,12 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
dynamic_chunk_training: bool = True,
short_chunk_threshold: float = 0.75,
causal: bool = True,
short_chunk_size: int = 25,
use_codebook_loss: bool = False,
num_codebooks: int = 4,
) -> None: ) -> None:
super(Conformer, self).__init__( super(Conformer, self).__init__(
num_features=num_features, num_features=num_features,
@ -69,6 +76,9 @@ class Conformer(Transformer):
normalize_before=normalize_before, normalize_before=normalize_before,
vgg_frontend=vgg_frontend, vgg_frontend=vgg_frontend,
) )
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -79,6 +89,7 @@ class Conformer(Transformer):
dropout, dropout,
cnn_module_kernel, cnn_module_kernel,
normalize_before, normalize_before,
causal,
) )
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before self.normalize_before = normalize_before
@ -112,9 +123,176 @@ class Conformer(Transformer):
# Caution: We assume the subsampling factor is 4! # Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2 lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) if self.dynamic_chunk_training:
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
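For reference, forward() above picks the training chunk size with a simple rule: roughly (1 - short_chunk_threshold) of the time the sampled value exceeds the threshold and the chunk covers the whole utterance (full context); otherwise a short chunk of 1 to short_chunk_size frames is used. A standalone sketch of that rule (the helper name is illustrative, not part of this repo):

import torch

def sample_chunk_size(
    max_len: int,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
) -> int:
    # With probability ~(1 - short_chunk_threshold): full context.
    # Otherwise: a short chunk of 1..short_chunk_size frames.
    chunk_size = torch.randint(1, max_len, (1,)).item()
    if chunk_size > max_len * short_chunk_threshold:
        return max_len
    return chunk_size % short_chunk_size + 1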
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
chunk_size: int = 16,
simulate_streaming: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
src_key_padding_mask = make_pad_mask(lengths)
if chunk_size < 0:
# Decoding with full right context.
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
assert x.size(0) == lengths.max().item()
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
# Since icefall currently only supports subsampling_rate == 4,
# the following parameters are hard-coded here.
# Change them accordingly if other subsampling rates are supported.
# The first encoder_out frame needs at least 7 fbank feature frames.
embed_left_context = 7
# Each subsequent encoder_out frame needs 4 more fbank feature frames.
subsampling_rate = 4
# So the number of extra frames needed to generate the first encoder_out frame is:
embed_conv_context = embed_left_context - subsampling_rate
stride = chunk_size * subsampling_rate
decoding_window = embed_conv_context + stride
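# A worked example with the defaults above (subsampling factor 4): chunk_size = 16
# gives stride = 16 * 4 = 64 fbank frames per chunk and
# decoding_window = 3 + 64 = 67, i.e. the first chunk consumes 67 fbank frames to
# produce 16 encoder_out frames, and every later chunk advances by 64 fbank frames.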
if simulate_streaming:
# Simulate chunk_by_chunk streaming decoding.
# Results of this branch should be identical to the following
# "else" branch, but this branch is a little slower
# since the features are fed chunk by chunk.
# Store the results of chunk_by_chunk decoding.
encoder_output = []
# caches
pos_emb_positive = []
pos_emb_negative = []
pos_emb_central = None
encoder_cache = [None for i in range(len(self.encoder.layers))]
conv_cache = [None for i in range(len(self.encoder.layers))]
# start chunk_by_chunk decoding
offset = 0
feature = x
num_frames = feature.size(1)
for cur in range(
0, num_frames - embed_left_context + 1, stride
):
end = min(cur + decoding_window, num_frames)
cur_feature = feature[:, cur:end, :]
cur_feature = self.encoder_embed(cur_feature)
cur_embed, cur_pos_emb = self.encoder_pos(
cur_feature, offset
)
cur_embed = cur_embed.permute(
1, 0, 2
) # (B, T, F) -> (T, B, F)
cur_T = cur_feature.size(1)
if cur == 0:
real_chunk_size = min(cur_T, chunk_size)
assert (
cur_pos_emb.size(1) == 2 * real_chunk_size - 1
), f"{cur_pos_emb.size(1)} == 2 * {real_chunk_size} - 1"
# Extract the central pos embedding during first chunk
pos_emb_central = cur_pos_emb[
0, (real_chunk_size - 1), :
].view(1, 1, -1)
cur_T -= 1
# Update the positive/negative position-embedding caches: for the
# first chunk only when chunk_size > 1, and for every later chunk.
if (cur_T > 1 and cur == 0) or cur != 0:
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
assert pos_emb_positive[-1].size(0) == cur_T
pos_emb_pos = torch.cat(
pos_emb_positive, dim=0
).unsqueeze(0)
pos_emb_neg = torch.cat(
pos_emb_negative, dim=0
).unsqueeze(0)
cur_pos_emb = torch.cat(
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
dim=1,
)
x = self.encoder.chunk_forward(
cur_embed,
cur_pos_emb,
src_key_padding_mask=src_key_padding_mask[
:, : offset + cur_embed.size(0)
],
encoder_cache=encoder_cache,
conv_cache=conv_cache,
offset=offset,
) # (T, B, F)
encoder_output.append(x)
offset += cur_embed.size(0)
assert num_frames - end <= 3
if num_frames != end:
logging.info(
f"The tailing {num_frames - end} frames fbank are not deocded."
)
x = torch.cat(encoder_output, dim=0)
else:
# Do NOT simulate chunk_by_chunk decoding.
# Results of this branch should be identical to the
# chunk_by_chunk branch above, but this branch is faster.
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
) # (T, N, C)
if self.normalize_before: if self.normalize_before:
x = self.after_norm(x) x = self.after_norm(x)
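A minimal sketch of how streaming_forward is invoked at decode time (variable names and tensor shapes are illustrative; decode.py below passes chunk_size and simulate_streaming the same way):

import torch

# conformer: an already constructed instance of the Conformer class above.
feature = torch.randn(2, 1000, 80)        # (N, T, C) fbank features
feature_lens = torch.tensor([1000, 800])

encoder_out, encoder_out_lens = conformer.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    chunk_size=16,             # 16 encoder_out frames per chunk
    simulate_streaming=True,   # feed fbank chunk by chunk; False uses the masked full forward
)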
@ -153,6 +331,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1, dropout: float = 0.1,
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
causal: bool = True,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(
@ -173,7 +352,9 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model), nn.Linear(dim_feedforward, d_model),
) )
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_ff_macaron = nn.LayerNorm( self.norm_ff_macaron = nn.LayerNorm(
d_model d_model
@ -263,13 +444,105 @@ class ConformerEncoderLayer(nn.Module):
return src return src
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
offset=0,
) -> Tensor:
"""
Pass the input through the encoder layer.
class ConformerEncoder(nn.Module): Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
if encoder_cache is None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = src
val = key
encoder_cache = key
else:
key = torch.cat([encoder_cache, src], dim=0)
val = key
encoder_cache = key
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
offset=offset,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
if self.normalize_before:
src = self.norm_conv(src)
if conv_cache is not None:
src = torch.cat([conv_cache, src], dim=0)
conv_cache = src
src = self.conv_module(src)
src = src[-residual.size(0) :, :, :] # noqa: E203
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src, encoder_cache, conv_cache
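The caching above is the usual streaming-attention trick: each chunk's queries attend over the concatenation of everything seen so far (encoder_cache) plus the current chunk, and the convolution module prepends conv_cache as left context before the causal convolution. A toy sketch of the growing key/value cache (standalone and illustrative, not the repo's API):

import torch

cache = None  # grows with every decoded chunk
chunks = torch.randn(5, 8, 2, 256).unbind(0)  # 5 chunks of shape (T=8, N=2, F=256)
for chunk in chunks:
    key = chunk if cache is None else torch.cat([cache, chunk], dim=0)
    cache = key
    # query = the current chunk only; key/value = all frames seen so far, i.e.
    # out = self_attn(query=chunk, key=key, value=key, ...)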
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers r"""ConformerEncoder is a stack of N encoder layers
Args: Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required). encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required). num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples:: Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -279,10 +552,11 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb) >>> out = conformer_encoder(src, pos_emb)
""" """
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: def __init__(
super().__init__() self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
self.layers = nn.ModuleList( ) -> None:
[copy.deepcopy(encoder_layer) for i in range(num_layers)] super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
) )
self.num_layers = num_layers self.num_layers = num_layers
@ -319,6 +593,55 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
) )
if self.norm is not None:
output = self.norm(output)
return output
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache=None,
conv_cache=None,
offset=0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
output = src
for layer_index, mod in enumerate(self.layers):
output, e_cache, c_cache = mod.chunk_forward(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
encoder_cache=encoder_cache[layer_index],
conv_cache=conv_cache[layer_index],
offset=offset,
)
encoder_cache[layer_index] = e_cache
conv_cache[layer_index] = c_cache
if self.norm is not None:
output = self.norm(output)
return output return output
@ -346,12 +669,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None: def extend_pe(self, x: Tensor, offset: int = 0) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
x_size_1 = offset + x.size(1)
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device x.device
@ -361,9 +685,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means the position of the query vector and `j` means the
# position of the key vector. We use positive relative positions when keys
# are to the left (i > j) and negative relative positions otherwise (i < j).
pe_positive = torch.zeros(x.size(1), self.d_model) pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model) pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model) * -(math.log(10000.0) / self.d_model)
@ -381,26 +705,35 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1) pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: def forward(
self, x: torch.Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding. """Add positional encoding.
Args: Args:
x (torch.Tensor): Input tensor (batch, time, `*`). x (torch.Tensor): Input tensor (batch, time, `*`).
offset: time index of the first frame of x; used to compute
the positional encoding in a streaming fashion.
Returns: Returns:
torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x) self.extend_pe(x, offset)
x = x * self.xscale x = x * self.xscale
x_size_1 = offset + x.size(1)
pos_emb = self.pe[ pos_emb = self.pe[
:, :,
self.pe.size(1) // 2 self.pe.size(1) // 2
- x.size(1) - x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203 + 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1), + x_size_1,
] ]
x_T = x.size(1)
if offset > 0:
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
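A small sketch of the offset behaviour above (constructor arguments and shapes are illustrative): for the first chunk the returned pos_emb has the usual length 2*T - 1, while for a later chunk starting at frame `offset` the most positive and most negative halves are returned back to back, giving length 2*T; streaming_forward above then stitches these per-chunk pieces together.

import torch

pos_enc = RelPositionalEncoding(256, 0.1)
chunk = torch.randn(1, 16, 256)        # (batch, time, d_model)

x0, pe0 = pos_enc(chunk, offset=0)     # pe0: (1, 2 * 16 - 1, 256)
x1, pe1 = pos_enc(chunk, offset=16)    # pe1: (1, 2 * 16, 256)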
@ -464,6 +797,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, need_weights: bool = True,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -522,9 +856,10 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=need_weights, need_weights=need_weights,
attn_mask=attn_mask, attn_mask=attn_mask,
offset=offset,
) )
def rel_shift(self, x: Tensor) -> Tensor: def rel_shift(self, x: Tensor, offset=0) -> Tensor:
"""Compute relative positional encoding. """Compute relative positional encoding.
Args: Args:
@ -533,18 +868,20 @@ class RelPositionMultiheadAttention(nn.Module):
Returns: Returns:
Tensor: tensor of shape (batch, head, time1, time2) Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for (note: time2 == time1 + offset, since it is for
the key, while time1 is for the query). the key, while time1 is for the query).
""" """
(batch_size, num_heads, time1, n) = x.shape (batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1 time2 = time1 + offset
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
# Note: TorchScript requires explicit arg for stride() # Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0) batch_stride = x.stride(0)
head_stride = x.stride(1) head_stride = x.stride(1)
time1_stride = x.stride(2) time1_stride = x.stride(2)
n_stride = x.stride(3) n_stride = x.stride(3)
return x.as_strided( return x.as_strided(
(batch_size, num_heads, time1, time1), (batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride), (batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1), storage_offset=n_stride * (time1 - 1),
) )
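To make the as_strided trick concrete: output element (i, j) of rel_shift is input element (i, time1 - 1 - i + j), which gives each query row its own slice of the relative-position scores. A small self-check (attn stands for any RelPositionMultiheadAttention instance; the values are illustrative):

import torch

batch, heads, time1, offset = 1, 1, 3, 2
time2 = time1 + offset
x = torch.arange(2 * time2 - 1, dtype=torch.float32).repeat(batch, heads, time1, 1)
x = x + 100 * torch.arange(time1, dtype=torch.float32).view(1, 1, time1, 1)

y = attn.rel_shift(x, offset=offset)
assert y.shape == (batch, heads, time1, time2)
for i in range(time1):
    for j in range(time2):
        assert y[0, 0, i, j] == x[0, 0, i, time1 - 1 - i + j]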
@ -566,6 +903,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, need_weights: bool = True,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -639,7 +977,6 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None: if _b is not None:
_b = _b[_start:_end] _b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b) q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias # This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias _b = in_proj_bias
_start = embed_dim _start = embed_dim
@ -745,7 +1082,9 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0) pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1 assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time2-1, head, d_k) --> (batch, head, d_k, 2*time2-1), where time2 = time1 + offset
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self.pos_bias_u).transpose( q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2 1, 2
@ -765,10 +1104,11 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1) q_with_bias_v, p
) # (batch, head, time1, 2*time1-1) ) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd) matrix_bd = self.rel_shift(
matrix_bd, offset=offset
) # [B, head, time1, time2]
attn_output_weights = ( attn_output_weights = (
matrix_ac + matrix_bd matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2) ) * scaling # (batch, head, time1, time2)
@ -835,11 +1175,16 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers. channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution (default=True).
""" """
def __init__( def __init__(
self, channels: int, kernel_size: int, bias: bool = True self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = True,
) -> None: ) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
@ -854,12 +1199,20 @@ class ConvolutionModule(nn.Module):
padding=0, padding=0,
bias=bias, bias=bias,
) )
assert (
causal
), "Currently, causal convolution is required for streaming conformer."
# Manually pad self.lorder zeros to the left in forward().
self.lorder = kernel_size - 1
padding = 0
self.depthwise_conv = nn.Conv1d( self.depthwise_conv = nn.Conv1d(
channels, channels,
channels, channels,
kernel_size, kernel_size,
stride=1, stride=1,
padding=(kernel_size - 1) // 2, padding=padding,
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
@ -892,6 +1245,10 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time) x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv # 1D Depthwise Conv
if self.lorder > 0:
# Make depthwise_conv causal by
# manually padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
# x is (batch, channels, time) # x is (batch, channels, time)
x = x.permute(0, 2, 1) x = x.permute(0, 2, 1)
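The manual left padding is what makes the depthwise convolution causal: with padding=0 and kernel_size - 1 zeros prepended, output frame t depends only on input frames t - (kernel_size - 1) .. t, and the output keeps the input length. A small standalone check (the sizes are illustrative):

import torch
import torch.nn as nn

kernel_size, channels, T = 31, 4, 100
depthwise = nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)

x = torch.randn(1, channels, T)
y = depthwise(nn.functional.pad(x, (kernel_size - 1, 0), "constant", 0.0))
assert y.shape == (1, channels, T)  # same length; y[..., t] sees only x[..., : t + 1]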


@ -18,18 +18,22 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless/decode.py \ ./streaming_pruned_transducer_stateless/decode.py \
--epoch 28 \ --simulate-streaming [True|False] \
--right-chunk-size [1/4/8/16/32/-1] \
--epoch 49 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \ --exp-dir ./streaming_pruned_transducer_stateless/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
./pruned_transducer_stateless/decode.py \ ./streaming_pruned_transducer_stateless/decode.py \
--epoch 28 \ --simulate-streaming [True|False] \
--right-chunk-size [1/4/8/16/32/-1] \
--epoch 49 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \ --exp-dir ./streaming_pruned_transducer_stateless/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@ -38,6 +42,7 @@ Usage:
import argparse import argparse
import logging import logging
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@ -52,12 +57,17 @@ from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from model import Transducer from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import (
average_checkpoints,
load_checkpoint,
save_checkpoint,
)
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -67,6 +77,29 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="Whether to split fbanks into chunks to simulate forward conformer"
"in a streaming fashion",
)
parser.add_argument(
"--tailing-dummy-frames",
type=int,
default=20,
help="tailing dummy frames padded to the right,"
"only used during decoding",
)
parser.add_argument(
"--right-chunk-size",
type=int,
default=16,
help="right context to attend during decoding",
)
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
@ -86,7 +119,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
default="pruned_transducer_stateless/exp", default="streaming_pruned_transducer_stateless/exp",
help="The experiment dir", help="The experiment dir",
) )
@ -145,6 +178,8 @@ def get_params() -> AttributeDict:
# parameters for decoder # parameters for decoder
"embedding_dim": 512, "embedding_dim": 512,
"env_info": get_env_info(), "env_info": get_env_info(),
# model average
"save_averaged_model": False,
} }
) )
return params return params
@ -236,10 +271,26 @@ def decode_one_batch(
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
# Extra dummy trailing frames may reduce deletion errors.
# example WITHOUT padding:
# CHAPTER SEVEN ON THE RACES OF MAN
# example WITH padding:
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
tailing_frames = (
torch.tensor([-23.0259])
.expand([feature.size(0), params.tailing_dummy_frames, 80])
.to(feature.device)
)
feature = torch.cat([feature, tailing_frames], dim=1)
supervisions["num_frames"] += params.tailing_dummy_frames
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature, x_lens=feature_lens x=feature,
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
simulate_streaming=params.simulate_streaming,
) )
hyps = [] hyps = []
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -395,6 +446,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-chunk_size-{params.right_chunk_size}"
params.suffix += f"-{params.simulate_streaming}"
params.suffix += f"-tailing-dummy-frams-{params.tailing_dummy_frames}"
if params.decoding_method == "beam_search": if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else: else:
@ -425,15 +479,24 @@ def main():
if params.avg == 1: if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else: else:
start = params.epoch - params.avg + 1 model_path = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # noqa: E501
filenames = [] if os.path.isfile(model_path):
for i in range(start, params.epoch + 1): load_checkpoint(model_path, model)
if start >= 0: else:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt") start = params.epoch - params.avg + 1
logging.info(f"averaging {filenames}") filenames = []
model.to(device) for i in range(start, params.epoch + 1):
model.load_state_dict(average_checkpoints(filenames, device=device)) if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.save_averaged_model:
save_checkpoint(
filename=model_path,
model=model,
)
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device model.device = device


@ -21,11 +21,13 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \ ./streaming_pruned_transducer_stateless/train.py \
--short-chunk-size=25 \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --full-libri 1 \
--num-epochs 50 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \ --exp-dir streaming_pruned_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
""" """
@ -74,6 +76,12 @@ def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
) )
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="chunk length of dynamic training",
)
parser.add_argument( parser.add_argument(
"--world-size", "--world-size",
@ -252,6 +260,8 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"dynamic_chunk_training": True,
"causal": True, # Now only causal convolution is verified
# parameters for decoder # parameters for decoder
"embedding_dim": 512, "embedding_dim": 512,
# parameters for Noam # parameters for Noam
@ -274,6 +284,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
causal=params.causal,
) )
return encoder return encoder


@ -694,6 +694,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(1) return expaned_lengths >= lengths.unsqueeze(1)
# From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
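As used in conformer.py above, the returned boolean mask is negated (the "~" in forward() and streaming_forward()) before being passed to the encoder, because the attention-mask convention there treats True as "masked out". The helper can also limit the left context; a short usage sketch (assuming the function above is in scope):

# Allowed positions (True = may attend) for 6 frames, chunks of 2, at most 1 left chunk.
allowed = subsequent_chunk_mask(size=6, chunk_size=2, num_left_chunks=1)
attn_mask = ~allowed  # True = masked, as expected by the encoder self-attention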
def l1_norm(x): def l1_norm(x):
return torch.sum(torch.abs(x)) return torch.sum(torch.abs(x))