streaming conformer pruned transducer stateless

This commit is contained in:
Guo Liyong 2022-03-07 16:49:46 +08:00
parent e03d237f9a
commit ee359f4d13
5 changed files with 521 additions and 51 deletions

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/beam_search.py

View File

@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
import warnings
from typing import Optional, Tuple
@ -24,7 +25,7 @@ import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(Transformer):
@ -56,6 +57,12 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
dynamic_chunk_training: bool = True,
short_chunk_threshold: float = 0.75,
causal: bool = True,
short_chunk_size: int = 25,
use_codebook_loss: bool = False,
num_codebooks: int = 4,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -69,6 +76,9 @@ class Conformer(Transformer):
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
)
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -79,6 +89,7 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
@ -112,9 +123,176 @@ class Conformer(Transformer):
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
if self.dynamic_chunk_training:
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
chunk_size: int = 16,
simulate_streaming: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
src_key_padding_mask = make_pad_mask(lengths)
if chunk_size < 0:
# Deocding with full-right context.
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
assert x.size(0) == lengths.max().item()
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
# As temporarily in icefall only subsampling_rate == 4 is supported,
# following parameters are hard-coded here.
# Change it accordingly if other subsamling_rate are supported.
# The first frame encoder_out needs at least 7 frames fbank feature
embed_left_context = 7
# Each successive frame needs 4 cached frames fbank feature
subsampling_rate = 4
# So the extra frames needed to generate the first frame encoder_out:
embed_conv_context = embed_left_context - subsampling_rate
stride = chunk_size * subsampling_rate
decoding_window = embed_conv_context + stride
if simulate_streaming:
# simulate chunk_by_chunk streaming decoding
# Results of this branch should be identical to following
# "else" branch.
# But this branch is a little slower
# as the feature is feeded chunk by chunk
# store the result of chunk_by_chunk decoding
encoder_output = []
# caches
pos_emb_positive = []
pos_emb_negative = []
pos_emb_central = None
encoder_cache = [None for i in range(len(self.encoder.layers))]
conv_cache = [None for i in range(len(self.encoder.layers))]
# start chunk_by_chunk decoding
offset = 0
feature = x
num_frames = feature.size(1)
for cur in range(
0, num_frames - embed_left_context + 1, stride
):
end = min(cur + decoding_window, num_frames)
cur_feature = feature[:, cur:end, :]
cur_feature = self.encoder_embed(cur_feature)
cur_embed, cur_pos_emb = self.encoder_pos(
cur_feature, offset
)
cur_embed = cur_embed.permute(
1, 0, 2
) # (B, T, F) -> (T, B, F)
cur_T = cur_feature.size(1)
if cur == 0:
real_chunk_size = min(cur_T, chunk_size)
assert (
cur_pos_emb.size(1) == 2 * real_chunk_size - 1
), f"{cur_pos_emb.size(1)} == 2 * {real_chunk_size} - 1"
# Extract the central pos embedding during first chunk
pos_emb_central = cur_pos_emb[
0, (real_chunk_size - 1), :
].view(1, 1, -1)
cur_T -= 1
# first chunk with chunk_size > 1
# or not first chunk
if (cur_T > 1 and cur == 0) or cur != 0:
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
assert pos_emb_positive[-1].size(0) == cur_T
pos_emb_pos = torch.cat(
pos_emb_positive, dim=0
).unsqueeze(0)
pos_emb_neg = torch.cat(
pos_emb_negative, dim=0
).unsqueeze(0)
cur_pos_emb = torch.cat(
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
dim=1,
)
x = self.encoder.chunk_forward(
cur_embed,
cur_pos_emb,
src_key_padding_mask=src_key_padding_mask[
:, : offset + cur_embed.size(0)
],
encoder_cache=encoder_cache,
conv_cache=conv_cache,
offset=offset,
) # (T, B, F)
encoder_output.append(x)
offset += cur_embed.size(0)
assert num_frames - end <= 3
if num_frames != end:
logging.info(
f"The tailing {num_frames - end} frames fbank are not deocded."
)
x = torch.cat(encoder_output, dim=0)
else:
# NOT simulate chunk_by_chunk decoding
# Results of this branch should be identical to previous
# simulate chunk_by_chunk decoding branch.
# But this branch is faster.
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
@ -153,6 +331,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
causal: bool = True,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
@ -173,7 +352,9 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_ff_macaron = nn.LayerNorm(
d_model
@ -263,13 +444,105 @@ class ConformerEncoderLayer(nn.Module):
return src
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
offset=0,
) -> Tensor:
"""
Pass the input through the encoder layer.
class ConformerEncoder(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
if encoder_cache is None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = src
val = key
encoder_cache = key
else:
key = torch.cat([encoder_cache, src], dim=0)
val = key
encoder_cache = key
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
offset=offset,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
if self.normalize_before:
src = self.norm_conv(src)
if conv_cache is not None:
src = torch.cat([conv_cache, src], dim=0)
conv_cache = src
src = self.conv_module(src)
src = src[-residual.size(0) :, :, :] # noqa: E203
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src, encoder_cache, conv_cache
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -279,10 +552,11 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb)
"""
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
def __init__(
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
) -> None:
super(ConformerEncoder, self).__init__(
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
)
self.num_layers = num_layers
@ -319,6 +593,55 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
)
if self.norm is not None:
output = self.norm(output)
return output
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache=None,
conv_cache=None,
offset=0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for layer_index, mod in enumerate(self.layers):
output, e_cache, c_cache = mod.chunk_forward(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
encoder_cache=encoder_cache[layer_index],
conv_cache=conv_cache[layer_index],
offset=offset,
)
encoder_cache[layer_index] = e_cache
conv_cache[layer_index] = c_cache
if self.norm is not None:
output = self.norm(output)
return output
@ -346,12 +669,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, offset: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = offset + x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
@ -361,9 +685,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -381,26 +705,35 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self, x: torch.Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
offset: time-index of the first frame of x.
used to compute positional encoding in a streaming fasion.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, offset)
x = x * self.xscale
x_size_1 = offset + x.size(1)
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
+ x_size_1,
]
x_T = x.size(1)
if offset > 0:
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
return self.dropout(x), self.dropout(pos_emb)
@ -464,6 +797,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -522,9 +856,10 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
offset=offset,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, offset=0) -> Tensor:
"""Compute relative positional encoding.
Args:
@ -533,18 +868,20 @@ class RelPositionMultiheadAttention(nn.Module):
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
(note: time2 == time1 + offset, since it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + offset
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
@ -566,6 +903,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -639,7 +977,6 @@ class RelPositionMultiheadAttention(nn.Module):
if _b is not None:
_b = _b[_start:_end]
q = nn.functional.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@ -745,7 +1082,9 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
@ -765,10 +1104,11 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
matrix_bd = self.rel_shift(
matrix_bd, offset=offset
) # [B, head, time1, time2]
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
@ -835,11 +1175,16 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convlution (default=True).
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = True,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
@ -854,12 +1199,20 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
assert (
causal
), "Currently, causal convolution is required for streaming conformer."
# Manualy padding self.lorder zeros to the left during forward.
self.lorder = kernel_size - 1
padding = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -892,6 +1245,10 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.lorder > 0:
# Make depthwise_conv causal by
# manualy padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)

View File

@ -18,18 +18,22 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
./streaming_pruned_transducer_stateless/decode.py \
--simulate-streaming [True|False] \
--right-chunk-size [1/4/8/16/32/-1] \
--epoch 49 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--exp-dir ./streaming_pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless/decode.py \
--epoch 28 \
./streaming_pruned_transducer_stateless/decode.py \
--simulate-streaming [True|False] \
--right-chunk-size [1/4/8/16/32/-1] \
--epoch 49 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless/exp \
--exp-dir ./streaming_pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
@ -38,6 +42,7 @@ Usage:
import argparse
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@ -52,12 +57,17 @@ from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
load_checkpoint,
save_checkpoint,
)
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
@ -67,6 +77,29 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="Whether to split fbanks into chunks to simulate forward conformer"
"in a streaming fashion",
)
parser.add_argument(
"--tailing-dummy-frames",
type=int,
default=20,
help="tailing dummy frames padded to the right,"
"only used during decoding",
)
parser.add_argument(
"--right-chunk-size",
type=int,
default=16,
help="right context to attend during decoding",
)
parser.add_argument(
"--epoch",
type=int,
@ -86,7 +119,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
default="streaming_pruned_transducer_stateless/exp",
help="The experiment dir",
)
@ -145,6 +178,8 @@ def get_params() -> AttributeDict:
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
# model average
"save_averaged_model": False,
}
)
return params
@ -236,10 +271,26 @@ def decode_one_batch(
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
# Extra dummy tailing frames my reduce deletion error
# example WITHOUT padding:
# CHAPTER SEVEN ON THE RACES OF MAN
# example WITH padding:
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
tailing_frames = (
torch.tensor([-23.0259])
.expand([feature.size(0), params.tailing_dummy_frames, 80])
.to(feature.device)
)
feature = torch.cat([feature, tailing_frames], dim=1)
supervisions["num_frames"] += params.tailing_dummy_frames
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
simulate_streaming=params.simulate_streaming,
)
hyps = []
batch_size = encoder_out.size(0)
@ -395,6 +446,9 @@ def main():
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-chunk_size-{params.right_chunk_size}"
params.suffix += f"-{params.simulate_streaming}"
params.suffix += f"-tailing-dummy-frams-{params.tailing_dummy_frames}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
else:
@ -425,15 +479,24 @@ def main():
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model_path = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # noqa: E501
if os.path.isfile(model_path):
load_checkpoint(model_path, model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
if params.save_averaged_model:
save_checkpoint(
filename=model_path,
model=model,
)
model.to(device)
model.eval()
model.device = device

View File

@ -21,11 +21,13 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
./streaming_pruned_transducer_stateless/train.py \
--short-chunk-size=25 \
--world-size 4 \
--num-epochs 30 \
--full-libri 1 \
--num-epochs 50 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--exp-dir streaming_pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300
"""
@ -74,6 +76,12 @@ def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="chunk length of dynamic training",
)
parser.add_argument(
"--world-size",
@ -252,6 +260,8 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"dynamic_chunk_training": True,
"causal": True, # Now only causal convolution is verified
# parameters for decoder
"embedding_dim": 512,
# parameters for Noam
@ -274,6 +284,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
causal=params.causal,
)
return encoder

View File

@ -694,6 +694,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(1)
# From https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def l1_norm(x):
return torch.sum(torch.abs(x))