support dynamic chunk streaming training
This commit is contained in:
parent 29e407fd04
commit 964af2caf7
@@ -32,7 +32,7 @@ from scaling import (
 )
 from torch import Tensor, nn

-from icefall.utils import make_pad_mask
+from icefall.utils import make_pad_mask, subsequent_chunk_mask


 class Conformer(EncoderInterface):
@@ -46,8 +46,27 @@ class Conformer(EncoderInterface):
         num_encoder_layers (int): number of encoder layers
         dropout (float): dropout rate
         layer_dropout (float): layer-dropout rate.
-        cnn_module_kernel (int): Kernel size of convolution module
-        vgg_frontend (bool): whether to use vgg frontend.
+        cnn_module_kernel (int): Kernel size of convolution module.
+        dynamic_chunk_training (bool): whether to use dynamic chunk training, if
+            you want to train a streaming model, this is expected to be True.
+            When setting True, it will use a masking strategy to make the attention
+            see only limited left and right context.
+        short_chunk_threshold (float): a threshold to determinize the chunk size
+            to be used in masking training, if the randomly generated chunk size
+            is greater than ``max_len * short_chunk_threshold`` (max_len is the
+            max sequence length of current batch) then it will use
+            full context in training (i.e. with chunk size equals to max_len).
+            This will be used only when dynamic_chunk_training is True.
+        short_chunk_size (int): see docs above, if the randomly generated chunk
+            size equals to or less than ``max_len * short_chunk_threshold``, the
+            chunk size will be sampled uniformly from 1 to short_chunk_size.
+            This also will be used only when dynamic_chunk_training is True.
+        num_left_chunks (int): the left context (in chunks) attention can see, the
+            chunk size is decided by short_chunk_threshold and short_chunk_size.
+            A minus value means seeing full left context.
+            This also will be used only when dynamic_chunk_training is True.
+        causal (bool): Whether to use causal convolution in conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training.
     """

     def __init__(
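To make the interaction between short_chunk_threshold and short_chunk_size concrete, here is a small standalone sketch of the sampling rule the docstring describes; the numbers (max_len, threshold) are illustrative and not taken from any recipe:

import random

max_len = 400                   # illustrative: longest sequence in the batch (frames)
short_chunk_threshold = 0.75
short_chunk_size = 25

chunk_size = random.randint(1, max_len - 1)          # mirrors torch.randint(1, max_len, (1,))
if chunk_size > max_len * short_chunk_threshold:
    chunk_size = max_len                             # roughly 25% of batches use full context
else:
    chunk_size = chunk_size % short_chunk_size + 1   # otherwise a chunk of 1..short_chunk_size
print(chunk_size)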
@@ -62,6 +81,11 @@ class Conformer(EncoderInterface):
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
         aux_layer_period: int = 3,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        num_left_chunks: int = -1,
+        causal: bool = False,
     ) -> None:
         super(Conformer, self).__init__()

@@ -79,6 +103,15 @@ class Conformer(EncoderInterface):

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)

+        self.encoder_layers = num_encoder_layers
+        self.d_model = d_model
+        self.cnn_module_kernel = cnn_module_kernel
+        self.causal = causal
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.num_left_chunks = num_left_chunks
+
         encoder_layer = ConformerEncoderLayer(
             d_model,
             nhead,
@@ -119,11 +152,39 @@ class Conformer(EncoderInterface):

         lengths = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == lengths.max().item()
-        mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths)

-        x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
-        )  # (T, N, C)
+        if self.dynamic_chunk_training:
+            assert (
+                self.causal
+            ), "Causal convolution is required for streaming conformer."
+            max_len = x.size(0)
+            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if chunk_size > (max_len * self.short_chunk_threshold):
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % self.short_chunk_size + 1
+
+            mask = ~subsequent_chunk_mask(
+                size=x.size(0),
+                chunk_size=chunk_size,
+                num_left_chunks=self.num_left_chunks,
+                device=x.device,
+            )
+            x = self.encoder(
+                x,
+                pos_emb,
+                mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )  # (T, N, C)
+        else:
+            x = self.encoder(
+                x,
+                pos_emb,
+                src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+            )  # (T, N, C)

         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

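The branch above limits each frame's attention to its own chunk plus num_left_chunks chunks to the left. The helper below, chunk_attention_mask, is a hypothetical illustration of that semantics written for this note; it is not icefall's subsequent_chunk_mask implementation, but it follows the same inverted convention used in the code above (True marks a position the query may NOT attend to):

import torch

def chunk_attention_mask(size: int, chunk_size: int, num_left_chunks: int = -1) -> torch.Tensor:
    idx = torch.arange(size)
    chunk_idx = idx // chunk_size                         # chunk index of each frame
    right = (chunk_idx + 1) * chunk_size - 1              # last visible frame: end of own chunk
    if num_left_chunks < 0:
        left = torch.zeros_like(idx)                      # full left context
    else:
        left = torch.clamp((chunk_idx - num_left_chunks) * chunk_size, min=0)
    cols = idx.unsqueeze(0)                               # (1, size)
    allowed = (cols >= left.unsqueeze(1)) & (cols <= right.unsqueeze(1))
    return ~allowed                                       # True = masked out

print(chunk_attention_mask(size=8, chunk_size=2, num_left_chunks=1).int())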
@@ -141,6 +202,8 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
+        causal (bool): Whether to use causal convolution in conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training and streaming decoding.

     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -157,6 +220,7 @@ class ConformerEncoderLayer(nn.Module):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        causal: bool = False,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()

@@ -184,7 +248,9 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

-        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+        self.conv_module = ConvolutionModule(
+            d_model, cnn_module_kernel, causal=causal
+        )

         self.norm_final = BasicNorm(d_model)

@@ -250,7 +316,8 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)

         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        conv, _ = self.conv_module(src)
+        src = src + self.dropout(conv)

         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
@@ -832,6 +899,39 @@ class RelPositionMultiheadAttention(nn.Module):
             )

         attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)

+        # If we are using dynamic_chunk_training and setting a limited
+        # num_left_chunks, the attention may only see the padding values which
+        # will also be masked out by `key_padding_mask`, at this circumstances,
+        # the whole column of `attn_output_weights` will be `-inf`
+        # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
+        # positions to avoid invalid loss value below.
+        if (
+            attn_mask is not None
+            and attn_mask.dtype == torch.bool
+            and key_padding_mask is not None
+        ):
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(
+                    0
+                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
+
+            attn_output_weights = attn_output_weights.view(
+                bsz, num_heads, tgt_len, src_len
+            )
+            attn_output_weights = attn_output_weights.masked_fill(
+                combined_mask, 0.0
+            )
+            attn_output_weights = attn_output_weights.view(
+                bsz * num_heads, tgt_len, src_len
+            )
+
         attn_output_weights = nn.functional.dropout(
             attn_output_weights, p=dropout_p, training=training
         )
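The comment block above exists because a query frame can end up with every key position masked (for example, a padded frame whose limited left-context window covers only padded frames); its whole row of scores is then -inf and becomes NaN after softmax. A tiny reproduction with made-up shapes, including the post-softmax fill used in the diff:

import torch

scores = torch.randn(2, 5)
combined_mask = torch.tensor([
    [False, False, True, True, True],   # normal row: some keys visible
    [True,  True,  True, True, True],   # fully masked row
])

scores = scores.masked_fill(combined_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights[1])                                   # tensor([nan, nan, nan, nan, nan])

weights = weights.masked_fill(combined_mask, 0.0)   # fill 0.0 at the masked positions
print(weights[1])                                   # tensor([0., 0., 0., 0., 0.])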
@@ -865,17 +965,24 @@ class ConvolutionModule(nn.Module):
         channels (int): The number of channels of conv layers.
         kernel_size (int): Kernerl size of conv layers.
         bias (bool): Whether to use bias in conv layers (default=True).
+        causal (bool): Whether to use causal convolution.

     """

     def __init__(
-        self, channels: int, kernel_size: int, bias: bool = True
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+        causal: bool = False,
     ) -> None:
         """Construct an ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

+        self.causal = causal
+
         self.pointwise_conv1 = ScaledConv1d(
             channels,
             2 * channels,
@@ -902,12 +1009,17 @@ class ConvolutionModule(nn.Module):
             channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
         )

+        self.lorder = kernel_size - 1
+        padding = (kernel_size - 1) // 2
+        if self.causal:
+            padding = 0
+
         self.depthwise_conv = ScaledConv1d(
             channels,
             channels,
             kernel_size,
             stride=1,
-            padding=(kernel_size - 1) // 2,
+            padding=padding,
             groups=channels,
             bias=bias,
         )
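When causal=True the symmetric 'SAME' padding is dropped (padding=0) and lorder = kernel_size - 1 zeros are later padded on the left only (see the forward changes below), so the output keeps the input length without looking at future frames. A quick shape check, using a plain nn.Conv1d as a stand-in for ScaledConv1d and illustrative sizes:

import torch
import torch.nn as nn

kernel_size, channels, T = 31, 4, 100
conv = nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)

x = torch.randn(1, channels, T)                   # (batch, channels, time)
lorder = kernel_size - 1
x = nn.functional.pad(x, (lorder, 0))             # pad on the left only
print(conv(x).shape)                              # torch.Size([1, 4, 100]) -- length preserved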
@@ -928,15 +1040,20 @@ class ConvolutionModule(nn.Module):
             initial_scale=0.25,
         )

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor, cache: Optional[Tensor] = None) -> Tensor:
         """Compute convolution module.

         Args:
             x: Input tensor (#time, batch, channels).
+            cache: The cache of depthwise_conv, only used in real streaming
+                decoding.

         Returns:
             Tensor: Output tensor (#time, batch, channels).
+            If cache is None return the output tensor (#time, batch, channels).
+            If cache is not None, return a tuple of Tensor, the first one is
+            the output tensor (#time, batch, channels), the second one is the
+            new cache for next chunk (#kernel_size - 1, batch, channels).
         """
         # exchange the temporal dimension and the feature dimension
         x = x.permute(1, 2, 0)  # (#batch, channels, time).
@@ -948,6 +1065,19 @@ class ConvolutionModule(nn.Module):
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
+        if self.causal and self.lorder > 0:
+            if cache is None:
+                # Make depthwise_conv causal by
+                # manualy padding self.lorder zeros to the left
+                x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                assert (
+                    not self.training
+                ), "Cache should be None in training time"
+                assert cache.size(0) == self.lorder
+                x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
+                cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
+
         x = self.depthwise_conv(x)

         x = self.deriv_balancer2(x)
@@ -955,7 +1085,11 @@ class ConvolutionModule(nn.Module):

         x = self.pointwise_conv2(x)  # (batch, channel, time)

-        return x.permute(2, 0, 1)
+        # torch.jit.script requires return types be the same as annotated above
+        if cache is None:
+            cache = torch.empty(0)
+
+        return x.permute(2, 0, 1), cache


 class Conv2dSubsampling(nn.Module):
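Taken together, the cache path lets the convolution module be driven chunk by chunk at decoding time: the last kernel_size - 1 frames of each (cache + chunk) input become the left context of the next chunk, and torch.empty(0) is only a placeholder so torch.jit.script sees a consistent return type. A minimal sketch of that loop, with a plain depthwise nn.Conv1d standing in for ScaledConv1d and illustrative sizes:

import torch
import torch.nn as nn

kernel_size, channels = 5, 8
lorder = kernel_size - 1
conv = nn.Conv1d(channels, channels, kernel_size, padding=0, groups=channels)

cache = torch.zeros(lorder, 1, channels)                       # (lorder, batch, channels)
outputs = []
for chunk in torch.randn(40, 1, channels).split(10, dim=0):    # 4 chunks of 10 frames
    x = chunk.permute(1, 2, 0)                                 # (batch, channels, time)
    x = torch.cat([cache.permute(1, 2, 0), x], dim=2)          # prepend cached left context
    cache = x.permute(2, 0, 1)[-lorder:]                       # keep the last lorder frames
    outputs.append(conv(x).permute(2, 0, 1))                   # back to (time, batch, channels)

print(torch.cat(outputs, dim=0).shape)                         # torch.Size([40, 1, 8])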
@@ -134,6 +134,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training, if you want a streaming
+        model, this requires to be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution, this requires to be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length of dynamic training, the chunk size would be either
+        max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How many left context can be seen in chunks when calculating attention.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
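The four flags above are ordinary argparse options (booleans go through icefall's str2bool). The sketch below wires up a minimal parser to show how they parse; the str2bool here is a self-contained stand-in written for this example, and the final assert mirrors the consistency check added to run() further down:

import argparse

def str2bool(v: str) -> bool:
    # hedged stand-in for icefall's str2bool helper
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

parser = argparse.ArgumentParser()
parser.add_argument("--dynamic-chunk-training", type=str2bool, default=False)
parser.add_argument("--causal-convolution", type=str2bool, default=False)
parser.add_argument("--short-chunk-size", type=int, default=25)
parser.add_argument("--num-left-chunks", type=int, default=4)

args = parser.parse_args(
    ["--dynamic-chunk-training", "True", "--causal-convolution", "True"]
)
assert not args.dynamic_chunk_training or args.causal_convolution
print(args.short_chunk_size, args.num_left_chunks)     # 25 4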
@@ -408,6 +442,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        dynamic_chunk_training=params.dynamic_chunk_training,
+        short_chunk_size=params.short_chunk_size,
+        num_left_chunks=params.num_left_chunks,
+        causal=params.causal_convolution,
     )
     return encoder

@@ -892,6 +930,11 @@ def run(rank, world_size, args):
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()

+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
     logging.info(params)

     logging.info("About to create model")