support streaming in conformer

pkufool 2022-05-18 23:26:24 +08:00
parent 6f7860a0a6
commit 5bd2490b44
5 changed files with 392 additions and 31 deletions

View File

@ -249,6 +249,25 @@ def get_parser():
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--streaming-mode",
type=str2bool,
default=False,
help="""
""",
)
parser.add_argument(
"--right-chunk-size",
type=int,
default=16,
help="right context to attend during decoding",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context to attend during decoding",
)
return parser
@ -301,9 +320,18 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.streaming_mode:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
streaming_data=False
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -526,6 +554,10 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.streaming_mode:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-use-LG-{params.use_LG}"
params.suffix += f"-beam-{params.beam}"
@ -561,6 +593,10 @@ def main():
logging.info(params)
logging.info("About to create model")
# TODO(wei kang): make the following config more elegant
params.dynamic_chunk_training = params.streaming_mode
params.short_chunk_size = 25
params.num_left_chunks = params.left_context // params.right_chunk_size
model = get_transducer_model(params)
if params.iter > 0:

View File

@ -222,6 +222,29 @@ def get_parser():
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="chunk length of dynamic training",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="chunk length of dynamic training",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True
""",
)
return parser
@ -310,6 +333,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=True if params.dynamic_chunk_training else False,
)
return encoder

View File

@ -18,13 +18,77 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import Tensor, nn
from transformer import Transformer
from icefall.utils import make_pad_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class DecodeStates(object):
def __init__(self,
layers: int,
left_context: int,
dim: int,
init: bool = True,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device('cpu')):
self.layers = layers
self.left_context = left_context
self.dim = dim
self.dtype = dtype
self.device = device
if init:
# shape: (layers, left_context, dim)
self.attn_cache = torch.zeros((layers, left_context, dim),
dtype=dtype,
device=device)
self.conv_cache = torch.zeros((layers, left_context, dim),
dtype=dtype,
device=device)
self.offset = torch.tensor([0], dtype=dtype, device=device)
@staticmethod
def stack(states: List['DecodeStates']) -> 'DecodeStates':
assert len(states) >= 1
obj = DecodeStates(layers=states[0].layers,
left_context=states[0].left_context,
dim=states[0].dim,
init=False,
dtype=states[0].dtype,
device=states[0].device)
attn_cache = []
conv_cache = []
offset = []
for i in range(len(states)):
attn_cache.append(states[i].attn_cache)
conv_cache.append(states[i].conv_cache)
offset.append(states[i].offset)
obj.attn_cache = torch.stack(attn_cache, dim=2)
obj.conv_cache = torch.stack(conv_cache, dim=2)
obj.offset = torch.stack(offset, dim=0)
return obj
@staticmethod
def unstack(states: 'DecodeStates') -> List['DecodeStates']:
results = []
attn_cache = torch.unbind(states.attn_cache, dim=2)
conv_cache = torch.unbind(states.conv_cache, dim=2)
offset = torch.unbind(states.offset, dim=0)
for i in range(states.attn_cache.size(2)):
obj = DecodeStates(layers=states.layers,
left_context=states.left_context,
dim=states.dim,
init=False,
dtype=states.dtype,
device=states.device)
obj.attn_cache = attn_cache[i]
obj.conv_cache = conv_cache[i]
obj.offset = offset[i]
results.append(obj)
return results
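A minimal shape check of the cache container above (a sketch; the import path is assumed to be this recipe's conformer.py):

from conformer import DecodeStates  # assumed import path within the recipe

# Two independent streams, each holding caches of shape (layers, left_context, dim).
s1 = DecodeStates(layers=12, left_context=64, dim=512)
s2 = DecodeStates(layers=12, left_context=64, dim=512)

# stack() batches the streams for one forward pass; the caches gain a stream
# axis at dim=2, i.e. (layers, left_context, num_streams, dim), which is the
# layout ConformerEncoder indexes per layer below.
batched = DecodeStates.stack([s1, s2])
assert batched.attn_cache.shape == (12, 64, 2, 512)
assert batched.conv_cache.shape == (12, 64, 2, 512)

# unstack() splits the batched caches back into per-stream states.
s1, s2 = DecodeStates.unstack(batched)
assert s1.attn_cache.shape == (12, 64, 512)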
class Conformer(Transformer):
@ -56,6 +120,11 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -70,6 +139,12 @@ class Conformer(Transformer):
vgg_frontend=vgg_frontend,
)
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
self.causal = causal
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
@ -79,6 +154,7 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
@ -115,9 +191,29 @@ class Conformer(Transformer):
lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
src_key_padding_mask = make_pad_mask(lengths)
mask = None
if self.dynamic_chunk_training:
assert (
self.causal
), "Causal convolution is required for streaming conformer."
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size,
num_left_chunks=self.num_left_chunks, device=x.device
)
x = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
@ -128,6 +224,80 @@ class Conformer(Transformer):
return logits, lengths
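The chunk-size sampling above follows the WeNet-style dynamic chunk training recipe: most draws yield a short chunk of at most short_chunk_size frames, and roughly 1 - short_chunk_threshold of the time the whole utterance becomes a single full-context chunk. A standalone rehearsal of that sampling with the defaults above (a sketch, independent of the model):

import torch

max_len = 100                  # frames after subsampling
short_chunk_threshold = 0.75
short_chunk_size = 25


def sample_chunk_size() -> int:
    chunk_size = torch.randint(1, max_len, (1,)).item()
    if chunk_size > max_len * short_chunk_threshold:
        return max_len                             # full-context chunk
    return chunk_size % short_chunk_size + 1       # limited to 1..short_chunk_size


sizes = [sample_chunk_size() for _ in range(1000)]
assert all(s == max_len or 1 <= s <= short_chunk_size for s in sizes)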
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
decode_states: Optional[DecodeStates] = None,
chunk_size: int = 32,
left_context: int = 64,
streaming_data: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[DecodeStates]]:
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
if streaming_data:
assert (
decode_states is not None
), "Require cache when sending data in streaming mode"
assert (
left_context == decode_states.left_context
), f"""The given left_context must equal to the left_context in
`decode_states`, need {decode_states.left_context} given
{left_context}."""
src_key_padding_mask = make_pad_mask(lengths + left_context)
embed = self.encoder_embed(x)
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
attn_cache=decode_states.attn_cache,
conv_cache=decode_states.conv_cache,
left_context=decode_states.left_context,
) # (T, B, F)
decode_states.offset += embed.size(0)
else:
assert decode_states is None
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
) # (T, N, C)
if self.normalize_before:
x = self.after_norm(x)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths, decode_states
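A sketch of driving streaming_forward chunk by chunk at inference time. Everything here is illustrative: chunks is a hypothetical iterable of (features, feature_lens) pairs with features of shape (1, T, num_features), and the encoder is assumed to be the Conformer defined in this file (e.g. model.encoder built by get_transducer_model, as in decode.py above).

from typing import Iterable, List, Tuple

import torch

from conformer import Conformer, DecodeStates  # assumed import path within the recipe


def decode_chunk_by_chunk(
    encoder: Conformer,
    chunks: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    chunk_size: int = 16,
    left_context: int = 64,
) -> List[torch.Tensor]:
    # The cache branches in the encoder layers are only taken in eval mode.
    encoder.eval()
    # Stack a single stream so the caches take the (layers, left_context, N, dim)
    # layout that ConformerEncoder indexes per layer.
    states = DecodeStates.stack(
        [
            DecodeStates(
                layers=len(encoder.encoder.layers),
                left_context=left_context,
                dim=encoder.encoder_pos.d_model,
            )
        ]
    )
    outputs = []
    with torch.no_grad():
        for features, feature_lens in chunks:
            out, out_lens, states = encoder.streaming_forward(
                x=features,
                x_lens=feature_lens,
                decode_states=states,
                chunk_size=chunk_size,
                left_context=left_context,
            )
            outputs.append(out)
    return outputs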
class ConformerEncoderLayer(nn.Module):
"""
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@ -156,6 +326,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
causal: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
@ -176,7 +347,9 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_ff_macaron = nn.LayerNorm(
d_model
@ -201,6 +374,9 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
attn_cache: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
"""
Pass the input through the encoder layer.
@ -233,13 +409,25 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_mha(src)
key = src
val = src
if not self.training and attn_cache is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([attn_cache, src], dim=0)
val = key
attn_cache = key
else:
assert left_context == 0
src_att = self.self_attn(
src,
src,
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
left_context=left_context,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
@ -249,7 +437,15 @@ class ConformerEncoderLayer(nn.Module):
residual = src
if self.normalize_before:
src = self.norm_conv(src)
src = residual + self.dropout(self.conv_module(src))
if not self.training and conv_cache is not None:
src = torch.cat([conv_cache, src], dim=0)
conv_cache = src
src = self.conv_module(src)
src = src[-residual.size(0) :, :, :] # noqa: E203
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
@ -264,7 +460,7 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_final(src)
return src
return src, attn_cache, conv_cache
class ConformerEncoder(nn.Module):
@ -295,6 +491,9 @@ class ConformerEncoder(nn.Module):
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
attn_cache: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
left_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -314,13 +513,26 @@ class ConformerEncoder(nn.Module):
"""
output = src
for mod in self.layers:
output = mod(
if self.training:
assert left_context == 0
assert attn_cache is None
assert conv_cache is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers):
output, a_cache, c_cache = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
attn_cache=None if attn_cache is None else attn_cache[layer_index],
conv_cache=None if conv_cache is None else conv_cache[layer_index],
left_context=left_context,
)
if attn_cache is not None and conv_cache is not None:
attn_cache[layer_index, ...] = a_cache[-left_context:, ...]
conv_cache[layer_index, ...] = c_cache[-left_context:, ...]
return output
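A small shape rehearsal of the cache update above, using plain torch: each layer attends over its cached left context concatenated with the new chunk, and only the newest left_context frames are written back for the next chunk.

import torch

left_context, chunk, N, dim = 64, 16, 2, 512
attn_cache = torch.zeros(left_context, N, dim)   # one layer's cache
src = torch.randn(chunk, N, dim)                 # the incoming chunk

key = torch.cat([attn_cache, src], dim=0)        # what the layer attends over
assert key.shape == (left_context + chunk, N, dim)

attn_cache = key[-left_context:, ...]            # rolled cache for the next chunk
assert attn_cache.shape == (left_context, N, dim)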
@ -349,12 +561,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
@ -364,9 +577,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means the position of the query vector and `j` means the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -384,7 +597,11 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self,
x: torch.Tensor,
context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
@ -395,14 +612,15 @@ class RelPositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, context)
x = x * self.xscale
x_size_1 = x.size(1) + context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
+ x_size_1,
]
return self.dropout(x), self.dropout(pos_emb)
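A quick check of the widened positional encoding (a sketch; RelPositionalEncoding is assumed importable from this file and is constructed with (d_model, dropout) as in Conformer.__init__ above): with a left context of 64 frames, the relative positions must span 2 * (chunk + left_context) - 1 offsets.

import torch

from conformer import RelPositionalEncoding  # assumed import path within the recipe

pos = RelPositionalEncoding(256, 0.1)
x = torch.randn(1, 16, 256)            # one chunk of 16 frames after subsampling

x_out, pos_emb = pos(x, context=64)    # ask for 64 extra left-context positions
assert pos_emb.shape == (1, 2 * (16 + 64) - 1, 256)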
@ -467,6 +685,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -525,9 +744,10 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
@ -540,14 +760,17 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + left_context
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
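The as_strided trick above turns scores over the 2 * time2 - 1 relative offsets into a (time1, time2) matrix, where time2 = time1 + left_context also covers the cached keys. A standalone copy of the same logic to sanity-check the shapes (a sketch, not the class method):

import torch


def rel_shift(x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    (batch_size, num_heads, time1, n) = x.shape
    time2 = time1 + left_context
    assert n == 2 * time2 - 1
    return x.as_strided(
        (batch_size, num_heads, time1, time2),
        (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
        storage_offset=x.stride(3) * (time1 - 1),
    )


# 16 query frames attending to themselves plus 64 cached left-context frames:
scores = torch.randn(2, 4, 16, 2 * (16 + 64) - 1)
assert rel_shift(scores, left_context=64).shape == (2, 4, 16, 16 + 64)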
@ -569,6 +792,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -748,7 +972,9 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time1-1, head, d_k) --> (batch, head, d_k, 2*time1-1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
@ -768,9 +994,10 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
attn_output_weights = (
matrix_ac + matrix_bd
@ -805,6 +1032,24 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training with a limited
# num_left_chunks, some queries may attend only to padding positions,
# which are also masked out by `key_padding_mask`. In that case the
# whole row of `attn_output_weights` for such a query is `-inf`
# (i.e. `nan` after softmax), so we fill `0.0` at the masked
# positions to avoid an invalid loss value below.
if attn_mask is not None and attn_mask.dtype == torch.bool and \
key_padding_mask is not None:
combined_mask = attn_mask.unsqueeze(
0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
combined_mask, 0.0)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
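The failure mode the masking above guards against is easy to reproduce: a query whose keys are all masked gets a row of -inf scores, and softmax over that row is all nan. Zeroing those rows keeps the loss finite and is harmless because the corresponding output frames are padding anyway. A minimal illustration:

import torch

scores = torch.full((1, 4), float("-inf"))   # every key masked out for this query
probs = torch.softmax(scores, dim=-1)
assert torch.isnan(probs).all()

fully_masked = torch.ones(1, 4, dtype=torch.bool)
probs = probs.masked_fill(fully_masked, 0.0)
assert (probs == 0).all()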
@ -842,12 +1087,17 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.causal = causal
self.pointwise_conv1 = nn.Conv1d(
channels,
@ -857,12 +1107,18 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
self.lorder = kernel_size - 1
padding = (kernel_size - 1) // 2
if self.causal:
padding = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -895,6 +1151,11 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.causal and self.lorder > 0:
# Make depthwise_conv causal by
# manually padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
x = self.depthwise_conv(x)
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
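The causal variant above replaces the symmetric 'SAME' padding with kernel_size - 1 zeros padded on the left only, so the output keeps the input length and frame t never sees frames after t. A standalone check of that property using plain torch (a sketch, not the ConvolutionModule class):

import torch
import torch.nn as nn
import torch.nn.functional as F

kernel_size, channels, T = 31, 8, 50
conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=0, groups=channels)

x = torch.randn(1, channels, T)
y = conv(F.pad(x, (kernel_size - 1, 0), "constant", 0.0))
assert y.shape == (1, channels, T)             # length preserved with padding=0

x2 = x.clone()
x2[..., -1] += 10.0                            # perturb only the last (future-most) frame
y2 = conv(F.pad(x2, (kernel_size - 1, 0), "constant", 0.0))
assert torch.equal(y[..., :-1], y2[..., :-1])  # earlier outputs are unaffected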

View File

@ -61,5 +61,6 @@ from .utils import (
setup_logger,
store_transcripts,
str2bool,
subsequent_chunk_mask,
write_error_stats,
)

View File

@ -693,6 +693,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(1)
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
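For example, with chunk_size=2 and num_left_chunks=1 each frame can attend to its own chunk plus one chunk of history. Note that Conformer inverts this mask (~) before passing it as attn_mask, so there True means "do not attend":

from icefall.utils import subsequent_chunk_mask

mask = subsequent_chunk_mask(size=6, chunk_size=2, num_left_chunks=1)
print(mask.long())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]])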
def l1_norm(x):
return torch.sum(torch.abs(x))