mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
streaming conformer pruned transducer stateless
This commit is contained in:
parent
e03d237f9a
commit
ee359f4d13
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless/beam_search.py
|
@ -15,7 +15,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
@ -24,7 +25,7 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Transformer
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
@ -56,6 +57,12 @@ class Conformer(Transformer):
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
dynamic_chunk_training: bool = True,
|
||||
short_chunk_threshold: float = 0.75,
|
||||
causal: bool = True,
|
||||
short_chunk_size: int = 25,
|
||||
use_codebook_loss: bool = False,
|
||||
num_codebooks: int = 4,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
@ -69,6 +76,9 @@ class Conformer(Transformer):
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
)
|
||||
self.dynamic_chunk_training = dynamic_chunk_training
|
||||
self.short_chunk_threshold = short_chunk_threshold
|
||||
self.short_chunk_size = short_chunk_size
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
@ -79,6 +89,7 @@ class Conformer(Transformer):
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
normalize_before,
|
||||
causal,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
self.normalize_before = normalize_before
|
||||
@ -112,9 +123,176 @@ class Conformer(Transformer):
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
|
||||
if self.dynamic_chunk_training:
|
||||
max_len = x.size(0)
|
||||
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||
)
|
||||
x = self.encoder(
|
||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, B, F)
|
||||
else:
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, N, C)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
|
||||
logits = self.encoder_output_layer(x)
|
||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return logits, lengths
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
chunk_size: int = 16,
|
||||
simulate_streaming: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# x: [N, T, C]
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
if chunk_size < 0:
|
||||
# Deocding with full-right context.
|
||||
x = self.encoder_embed(x)
|
||||
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, B, F)
|
||||
else:
|
||||
# As temporarily in icefall only subsampling_rate == 4 is supported,
|
||||
# following parameters are hard-coded here.
|
||||
# Change it accordingly if other subsamling_rate are supported.
|
||||
# The first frame encoder_out needs at least 7 frames fbank feature
|
||||
embed_left_context = 7
|
||||
# Each successive frame needs 4 cached frames fbank feature
|
||||
subsampling_rate = 4
|
||||
# So the extra frames needed to generate the first frame encoder_out:
|
||||
embed_conv_context = embed_left_context - subsampling_rate
|
||||
|
||||
stride = chunk_size * subsampling_rate
|
||||
decoding_window = embed_conv_context + stride
|
||||
if simulate_streaming:
|
||||
# simulate chunk_by_chunk streaming decoding
|
||||
# Results of this branch should be identical to following
|
||||
# "else" branch.
|
||||
# But this branch is a little slower
|
||||
# as the feature is feeded chunk by chunk
|
||||
|
||||
# store the result of chunk_by_chunk decoding
|
||||
encoder_output = []
|
||||
|
||||
# caches
|
||||
pos_emb_positive = []
|
||||
pos_emb_negative = []
|
||||
pos_emb_central = None
|
||||
encoder_cache = [None for i in range(len(self.encoder.layers))]
|
||||
conv_cache = [None for i in range(len(self.encoder.layers))]
|
||||
|
||||
# start chunk_by_chunk decoding
|
||||
offset = 0
|
||||
feature = x
|
||||
num_frames = feature.size(1)
|
||||
for cur in range(
|
||||
0, num_frames - embed_left_context + 1, stride
|
||||
):
|
||||
end = min(cur + decoding_window, num_frames)
|
||||
cur_feature = feature[:, cur:end, :]
|
||||
cur_feature = self.encoder_embed(cur_feature)
|
||||
cur_embed, cur_pos_emb = self.encoder_pos(
|
||||
cur_feature, offset
|
||||
)
|
||||
cur_embed = cur_embed.permute(
|
||||
1, 0, 2
|
||||
) # (B, T, F) -> (T, B, F)
|
||||
|
||||
cur_T = cur_feature.size(1)
|
||||
if cur == 0:
|
||||
real_chunk_size = min(cur_T, chunk_size)
|
||||
assert (
|
||||
cur_pos_emb.size(1) == 2 * real_chunk_size - 1
|
||||
), f"{cur_pos_emb.size(1)} == 2 * {real_chunk_size} - 1"
|
||||
|
||||
# Extract the central pos embedding during first chunk
|
||||
pos_emb_central = cur_pos_emb[
|
||||
0, (real_chunk_size - 1), :
|
||||
].view(1, 1, -1)
|
||||
cur_T -= 1
|
||||
|
||||
# first chunk with chunk_size > 1
|
||||
# or not first chunk
|
||||
if (cur_T > 1 and cur == 0) or cur != 0:
|
||||
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
|
||||
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
|
||||
|
||||
assert pos_emb_positive[-1].size(0) == cur_T
|
||||
|
||||
pos_emb_pos = torch.cat(
|
||||
pos_emb_positive, dim=0
|
||||
).unsqueeze(0)
|
||||
pos_emb_neg = torch.cat(
|
||||
pos_emb_negative, dim=0
|
||||
).unsqueeze(0)
|
||||
cur_pos_emb = torch.cat(
|
||||
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
x = self.encoder.chunk_forward(
|
||||
cur_embed,
|
||||
cur_pos_emb,
|
||||
src_key_padding_mask=src_key_padding_mask[
|
||||
:, : offset + cur_embed.size(0)
|
||||
],
|
||||
encoder_cache=encoder_cache,
|
||||
conv_cache=conv_cache,
|
||||
offset=offset,
|
||||
) # (T, B, F)
|
||||
encoder_output.append(x)
|
||||
offset += cur_embed.size(0)
|
||||
|
||||
assert num_frames - end <= 3
|
||||
if num_frames != end:
|
||||
logging.info(
|
||||
f"The tailing {num_frames - end} frames fbank are not deocded."
|
||||
)
|
||||
x = torch.cat(encoder_output, dim=0)
|
||||
|
||||
else:
|
||||
# NOT simulate chunk_by_chunk decoding
|
||||
# Results of this branch should be identical to previous
|
||||
# simulate chunk_by_chunk decoding branch.
|
||||
# But this branch is faster.
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
) # (T, N, C)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
@ -153,6 +331,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
causal: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
@ -173,7 +352,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
self.conv_module = ConvolutionModule(
|
||||
d_model, cnn_module_kernel, causal=causal
|
||||
)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
@ -263,13 +444,105 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
return src
|
||||
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
encoder_cache: Optional[Tensor] = None,
|
||||
conv_cache: Optional[Tensor] = None,
|
||||
offset=0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
# macaron style feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
if encoder_cache is None:
|
||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
key = src
|
||||
val = key
|
||||
encoder_cache = key
|
||||
else:
|
||||
key = torch.cat([encoder_cache, src], dim=0)
|
||||
val = key
|
||||
encoder_cache = key
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
key,
|
||||
val,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
offset=offset,
|
||||
)[0]
|
||||
src = residual + self.dropout(src_att)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
# convolution module
|
||||
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
if conv_cache is not None:
|
||||
src = torch.cat([conv_cache, src], dim=0)
|
||||
conv_cache = src
|
||||
|
||||
src = self.conv_module(src)
|
||||
src = src[-residual.size(0) :, :, :] # noqa: E203
|
||||
|
||||
src = residual + self.dropout(src)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
# feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
|
||||
if self.normalize_before:
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src, encoder_cache, conv_cache
|
||||
|
||||
|
||||
class ConformerEncoder(nn.TransformerEncoder):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
norm: the layer normalization component (optional).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
@ -279,10 +552,11 @@ class ConformerEncoder(nn.Module):
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
def __init__(
|
||||
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
|
||||
) -> None:
|
||||
super(ConformerEncoder, self).__init__(
|
||||
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
@ -319,6 +593,55 @@ class ConformerEncoder(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
encoder_cache=None,
|
||||
conv_cache=None,
|
||||
offset=0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
output, e_cache, c_cache = mod.chunk_forward(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
encoder_cache=encoder_cache[layer_index],
|
||||
conv_cache=conv_cache[layer_index],
|
||||
offset=offset,
|
||||
)
|
||||
encoder_cache[layer_index] = e_cache
|
||||
conv_cache[layer_index] = c_cache
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@ -346,12 +669,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
def extend_pe(self, x: Tensor, offset: int = 0) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
x_size_1 = offset + x.size(1)
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
@ -361,9 +685,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
@ -381,26 +705,35 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, offset: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
offset: time-index of the first frame of x.
|
||||
used to compute positional encoding in a streaming fasion.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
self.extend_pe(x, offset)
|
||||
x = x * self.xscale
|
||||
x_size_1 = offset + x.size(1)
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
- x_size_1
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
+ x_size_1,
|
||||
]
|
||||
x_T = x.size(1)
|
||||
if offset > 0:
|
||||
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
|
||||
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
@ -464,6 +797,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
offset=0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -522,9 +856,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
def rel_shift(self, x: Tensor, offset=0) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
@ -533,18 +868,20 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
(note: time2 == time1 + offset, since it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
time2 = time1 + offset
|
||||
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
@ -566,6 +903,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
offset=0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -639,7 +977,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
@ -745,7 +1082,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||
1, 2
|
||||
@ -765,10 +1104,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
q_with_bias_v, p
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
matrix_bd = self.rel_shift(
|
||||
matrix_bd, offset=offset
|
||||
) # [B, head, time1, time2]
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
@ -835,11 +1175,16 @@ class ConvolutionModule(nn.Module):
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
causal (bool): Whether to use causal convlution (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
causal: bool = True,
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
@ -854,12 +1199,20 @@ class ConvolutionModule(nn.Module):
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
assert (
|
||||
causal
|
||||
), "Currently, causal convolution is required for streaming conformer."
|
||||
|
||||
# Manualy padding self.lorder zeros to the left during forward.
|
||||
self.lorder = kernel_size - 1
|
||||
padding = 0
|
||||
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
@ -892,6 +1245,10 @@ class ConvolutionModule(nn.Module):
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if self.lorder > 0:
|
||||
# Make depthwise_conv causal by
|
||||
# manualy padding self.lorder zeros to the left
|
||||
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
@ -18,18 +18,22 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
./streaming_pruned_transducer_stateless/decode.py \
|
||||
--simulate-streaming [True|False] \
|
||||
--right-chunk-size [1/4/8/16/32/-1] \
|
||||
--epoch 49 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--exp-dir ./streaming_pruned_transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
./streaming_pruned_transducer_stateless/decode.py \
|
||||
--simulate-streaming [True|False] \
|
||||
--right-chunk-size [1/4/8/16/32/-1] \
|
||||
--epoch 49 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--exp-dir ./streaming_pruned_transducer_stateless/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
@ -38,6 +42,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
@ -52,12 +57,17 @@ from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
load_checkpoint,
|
||||
save_checkpoint,
|
||||
)
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
@ -67,6 +77,29 @@ def get_parser():
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to split fbanks into chunks to simulate forward conformer"
|
||||
"in a streaming fashion",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tailing-dummy-frames",
|
||||
type=int,
|
||||
default=20,
|
||||
help="tailing dummy frames padded to the right,"
|
||||
"only used during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--right-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="right context to attend during decoding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
@ -86,7 +119,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless/exp",
|
||||
default="streaming_pruned_transducer_stateless/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -145,6 +178,8 @@ def get_params() -> AttributeDict:
|
||||
# parameters for decoder
|
||||
"embedding_dim": 512,
|
||||
"env_info": get_env_info(),
|
||||
# model average
|
||||
"save_averaged_model": False,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -236,10 +271,26 @@ def decode_one_batch(
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
# Extra dummy tailing frames my reduce deletion error
|
||||
# example WITHOUT padding:
|
||||
# CHAPTER SEVEN ON THE RACES OF MAN
|
||||
# example WITH padding:
|
||||
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
|
||||
tailing_frames = (
|
||||
torch.tensor([-23.0259])
|
||||
.expand([feature.size(0), params.tailing_dummy_frames, 80])
|
||||
.to(feature.device)
|
||||
)
|
||||
feature = torch.cat([feature, tailing_frames], dim=1)
|
||||
supervisions["num_frames"] += params.tailing_dummy_frames
|
||||
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
simulate_streaming=params.simulate_streaming,
|
||||
)
|
||||
hyps = []
|
||||
batch_size = encoder_out.size(0)
|
||||
@ -395,6 +446,9 @@ def main():
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
params.suffix += f"-chunk_size-{params.right_chunk_size}"
|
||||
params.suffix += f"-{params.simulate_streaming}"
|
||||
params.suffix += f"-tailing-dummy-frams-{params.tailing_dummy_frames}"
|
||||
if params.decoding_method == "beam_search":
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
@ -425,15 +479,24 @@ def main():
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
model_path = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # noqa: E501
|
||||
if os.path.isfile(model_path):
|
||||
load_checkpoint(model_path, model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
||||
if params.save_averaged_model:
|
||||
save_checkpoint(
|
||||
filename=model_path,
|
||||
model=model,
|
||||
)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
@ -21,11 +21,13 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless/train.py \
|
||||
./streaming_pruned_transducer_stateless/train.py \
|
||||
--short-chunk-size=25 \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--full-libri 1 \
|
||||
--num-epochs 50 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--exp-dir streaming_pruned_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
"""
|
||||
@ -74,6 +76,12 @@ def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="chunk length of dynamic training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
@ -252,6 +260,8 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"dynamic_chunk_training": True,
|
||||
"causal": True, # Now only causal convolution is verified
|
||||
# parameters for decoder
|
||||
"embedding_dim": 512,
|
||||
# parameters for Noam
|
||||
@ -274,6 +284,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
causal=params.causal,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -694,6 +694,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
return expaned_lengths >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
# From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
|
||||
def subsequent_chunk_mask(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||
ret[i, start:ending] = True
|
||||
return ret
|
||||
|
||||
|
||||
def l1_norm(x):
|
||||
return torch.sum(torch.abs(x))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user