support streaming on pruned_transducer_stateless2; add delay penalty; fixes for decode states

pkufool 2022-05-22 17:14:17 +08:00
parent 118b09463d
commit 7cc697c03a
8 changed files with 637 additions and 162 deletions

View File

@ -66,6 +66,8 @@ class Transducer(nn.Module):
prune_range: int = 5, prune_range: int = 5,
am_scale: float = 0.0, am_scale: float = 0.0,
lm_scale: float = 0.0, lm_scale: float = 0.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -136,10 +138,31 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale, lm_only_scale=lm_scale,
am_only_scale=am_scale, am_only_scale=am_scale,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(
T0, device=px_grad.device
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range] # ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges( ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad, px_grad=px_grad,
@ -163,7 +186,8 @@ class Transducer(nn.Module):
ranges=ranges, ranges=ranges,
termination_symbol=blank_id, termination_symbol=blank_id,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
) )
return (simple_loss, pruned_loss) return (simple_loss, pruned_loss, sym_delay)
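The `sym_delay` statistic above is a weighted average: `px_grad` returned by k2's smoothed RNN-T loss behaves (up to sign convention) like the posterior probability that the s-th symbol is emitted at frame t, each emission is weighted by its frame index relative to the midpoint of the sequence, and the sum is normalized by the total number of symbols. A minimal standalone sketch of the same computation; the function name and the `posterior` argument are illustrative, not part of the commit:

import torch

def average_symbol_delay(posterior: torch.Tensor, boundary: torch.Tensor) -> torch.Tensor:
    """posterior: (B, S, T+1), per-frame emission posteriors (the role played
    by px_grad above); boundary: (B, 4) as in k2, where boundary[:, 2] is the
    number of symbols and boundary[:, 3] the number of frames of each sequence."""
    B, S, T0 = posterior.shape
    # a symbol emitted exactly mid-utterance counts as zero delay
    center = (boundary[:, 3] - 1) / 2                       # (B,)
    frame_idx = torch.arange(T0, dtype=posterior.dtype)     # (T+1,)
    offset = frame_idx.view(1, 1, T0) - center.view(B, 1, 1)
    total_syms = boundary[:, 2].sum()
    return (posterior * offset).sum() / total_syms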

View File

@ -269,6 +269,25 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -536,14 +555,17 @@ def compute_loss(
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
sym_delay = None
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model( simple_loss, pruned_loss, sym_delay = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
prune_range=params.prune_range, prune_range=params.prune_range,
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss loss = params.simple_loss_scale * simple_loss + pruned_loss
@ -561,6 +583,9 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if sym_delay is not None:
info["sym_delay"] = sym_delay.detatch().cpu().item()
return loss, info return loss, info

View File

@ -32,7 +32,7 @@ from scaling import (
) )
from torch import Tensor, nn from torch import Tensor, nn
from icefall.utils import make_pad_mask from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface): class Conformer(EncoderInterface):
@ -48,6 +48,26 @@ class Conformer(EncoderInterface):
layer_dropout (float): layer-dropout rate. layer_dropout (float): layer-dropout rate.
cnn_module_kernel (int): Kernel size of convolution module cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend. vgg_frontend (bool): whether to use vgg frontend.
dynamic_chunk_training (bool): whether to use dynamic chunk training; if
you want to train a streaming model, this is expected to be True.
When setting True, it will use a masking strategy to make the attention
see only limited left and right context.
short_chunk_threshold (float): a threshold to determine the chunk size
to be used in masking training, if the randomly generated chunk size
is greater than ``max_len * short_chunk_threshold`` (max_len is the
max sequence length of current batch) then it will use
full context in training (i.e. with chunk size equal to max_len).
This will be used only when dynamic_chunk_training is True.
short_chunk_size (int): see docs above; if the randomly generated chunk
size is less than or equal to ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context (in chunks) attention can see, the
chunk size is decided by short_chunk_threshold and short_chunk_size.
A negative value means seeing the full left context.
This also will be used only when dynamic_chunk_training is True.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training.
""" """
def __init__( def __init__(
@ -61,6 +81,11 @@ class Conformer(EncoderInterface):
dropout: float = 0.1, dropout: float = 0.1,
layer_dropout: float = 0.075, layer_dropout: float = 0.075,
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
) -> None: ) -> None:
super(Conformer, self).__init__() super(Conformer, self).__init__()
@ -76,6 +101,14 @@ class Conformer(EncoderInterface):
# (2) embedding: num_features -> d_model # (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_layers = num_encoder_layers
self.d_model = d_model
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
self.causal = causal
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer( encoder_layer = ConformerEncoderLayer(
@ -85,6 +118,7 @@ class Conformer(EncoderInterface):
dropout, dropout,
layer_dropout, layer_dropout,
cnn_module_kernel, cnn_module_kernel,
causal,
) )
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
@ -117,10 +151,31 @@ class Conformer(EncoderInterface):
# Caution: We assume the subsampling factor is 4! # Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2 lengths = ((x_lens - 1) // 2 - 1) // 2
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder( src_key_padding_mask = make_pad_mask(lengths)
x, pos_emb, src_key_padding_mask=mask, warmup=warmup mask = None
if self.dynamic_chunk_training:
assert (
self.causal
), "Causal convolution is required for streaming conformer."
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size,
num_left_chunks=self.num_left_chunks, device=x.device
)
x, _ = self.encoder(
x, pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C) ) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -128,6 +183,116 @@ class Conformer(EncoderInterface):
return x, lengths return x, lengths
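The chunk mask used in the forward pass above comes from icefall's subsequent_chunk_mask. Roughly, it returns a (size, size) boolean matrix in which frame i may attend to the frames of its own chunk and of up to num_left_chunks chunks to its left, and to nothing on its right; the code negates it because a boolean attn_mask marks disallowed positions with True. A rough behavioural sketch (not the library implementation):

import torch

def chunk_attention_mask(size: int, chunk_size: int, num_left_chunks: int = -1) -> torch.Tensor:
    """Approximation of icefall's subsequent_chunk_mask: a (size, size) bool
    matrix where True marks an *allowed* query->key position."""
    mask = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        chunk_idx = i // chunk_size
        if num_left_chunks < 0:
            start = 0  # full left context
        else:
            start = max(0, (chunk_idx - num_left_chunks) * chunk_size)
        end = min(size, (chunk_idx + 1) * chunk_size)  # end of the current chunk
        mask[i, start:end] = True
    return mask

# e.g. chunk_attention_mask(6, 2, 1): frame 4 may see frames 2..5 but not 0..1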
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
warmup: float = 1.0,
states: Optional[Tensor] = None,
chunk_size: int = 16,
left_context: int = 64,
simulate_streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
states:
The decode states for previous frames which contains the cached data.
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
left_context:
How many previous frames the attention can see in the current chunk; it MUST
be equal to the left_context used to create `states`.
simulate_streaming:
If True, it will use a masking strategy to simulate streaming
(i.e. each chunk sees only a limited left and right context).
The whole sequence is supposed to be sent at once when using
simulate_streaming.
Returns:
Return a tuple containing 3 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- states, the updated states containing the information
of the current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
if not simulate_streaming:
assert (
states is not None
), "Require cache when sending data in streaming mode"
assert (
states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
), f"""The shape of states MUST be equal to
(2, encoder_layers, left_context, batch, d_model) which is
{(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
given {states.shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context)
embed = self.encoder_embed(x)
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x, states = self.encoder(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
states=states,
left_context=left_context,
) # (T, B, F)
else:
assert states is None
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device
)
x, _ = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths, states
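When simulate_streaming is False, the caller owns the cache: it allocates the states tensor with the shape documented above and passes the updated tensor back chunk by chunk. A minimal sketch of that loop; `model` and the `feature_chunks` iterator are placeholders, and zero-initialized caches for the first chunk are an assumption (the assert above only checks the shape):

import torch

left_context = 64  # frames after subsampling; must match the argument passed per chunk
batch_size = 1

# states[0] caches attention keys/values, states[1] caches convolution inputs
states = torch.zeros(
    2,
    model.encoder.encoder_layers,
    left_context,
    batch_size,
    model.encoder.d_model,
)

for chunk_feats, chunk_lens in feature_chunks:  # placeholder iterator over feature chunks
    encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
        x=chunk_feats,
        x_lens=chunk_lens,
        states=states,
        chunk_size=16,
        left_context=left_context,
        simulate_streaming=False,
    )
    # run greedy_search / modified_beam_search on encoder_out here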
class ConformerEncoderLayer(nn.Module): class ConformerEncoderLayer(nn.Module):
""" """
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
@ -139,6 +304,8 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048). dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1). dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module. cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
Examples:: Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -155,6 +322,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1, dropout: float = 0.1,
layer_dropout: float = 0.075, layer_dropout: float = 0.075,
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
causal: bool = False,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
@ -182,7 +350,11 @@ class ConformerEncoderLayer(nn.Module):
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
) )
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) self.conv_module = ConvolutionModule(
d_model,
cnn_module_kernel,
causal=causal
)
self.norm_final = BasicNorm(d_model) self.norm_final = BasicNorm(d_model)
@ -200,7 +372,9 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0, warmup: float = 1.0,
) -> Tensor: states: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Tensor]:
""" """
Pass the input through the encoder layer. Pass the input through the encoder layer.
@ -211,10 +385,17 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently. bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E) pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
src_mask: (S, S). src_mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number S is the source sequence length, N is the batch size, E is the feature number
@ -236,19 +417,38 @@ class ConformerEncoderLayer(nn.Module):
# macaron style feed forward module # macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src)) src = src + self.dropout(self.feed_forward_macaron(src))
key = src
val = src
if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([states[0, ...], src], dim=0)
val = key
states[0, ...] = key[-left_context:, ...]
else:
assert left_context == 0
# multi-headed self-attention module # multi-headed self-attention module
src_att = self.self_attn( src_att = self.self_attn(
src, src,
src, key,
src, val,
pos_emb=pos_emb, pos_emb=pos_emb,
attn_mask=src_mask, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask, key_padding_mask=src_key_padding_mask,
left_context=left_context,
)[0] )[0]
src = src + self.dropout(src_att) src = src + self.dropout(src_att)
# convolution module # convolution module
src = src + self.dropout(self.conv_module(src)) conv_src = src
if not self.training and states is not None:
# src: [chunk_size, N, F]; prepend the cached left context
conv_src = torch.cat([states[1, ...], src], dim=0)
# keep the last left_context frames as the cache for the next chunk
states[1, ...] = conv_src[-left_context:, ...]
conv = self.conv_module(conv_src)
conv = conv[-src.size(0) :, :, :]  # noqa: E203
src = src + self.dropout(conv)
# feed forward module # feed forward module
src = src + self.dropout(self.feed_forward(src)) src = src + self.dropout(self.feed_forward(src))
@ -258,7 +458,7 @@ class ConformerEncoderLayer(nn.Module):
if alpha != 1.0: if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig src = alpha * src + (1 - alpha) * src_orig
return src return src, states
class ConformerEncoder(nn.Module): class ConformerEncoder(nn.Module):
@ -290,7 +490,9 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0, warmup: float = 1.0,
) -> Tensor: states: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Tensor]:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
Args: Args:
@ -298,10 +500,19 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional). mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
states:
The decode states for previous frames which contains the cached data.
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E) pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
mask: (S, S). mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@ -309,16 +520,26 @@ class ConformerEncoder(nn.Module):
""" """
output = src output = src
for i, mod in enumerate(self.layers): if self.training:
output = mod( assert left_context == 0
assert states is None
else:
assert left_context >= 0
for layer_index, mod in enumerate(self.layers):
output, cache = mod(
output, output,
pos_emb, pos_emb,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
warmup=warmup, warmup=warmup,
states=None if states is None else states[:, layer_index, ...],
left_context=left_context,
) )
if states is not None:
states[:, layer_index, ...] = cache
return output return output, states
class RelPositionalEncoding(torch.nn.Module): class RelPositionalEncoding(torch.nn.Module):
@ -344,12 +565,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len)) self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None: def extend_pe(self, x: Tensor, context: int = 0) -> None:
"""Reset the positional encodings.""" """Reset the positional encodings."""
x_size_1 = x.size(1) + context
if self.pe is not None: if self.pe is not None:
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1: if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device x.device
@ -359,9 +581,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means the position of the query vector and `j` means the # Suppose `i` means the position of the query vector and `j` means the
# position of the key vector. We use positive relative positions when keys # position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j). # are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model) pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model) pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32) torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model) * -(math.log(10000.0) / self.d_model)
@ -379,24 +601,32 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1) pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype) self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]: def forward(
self,
x: torch.Tensor,
context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding. """Add positional encoding.
Args: Args:
x (torch.Tensor): Input tensor (batch, time, `*`). x (torch.Tensor): Input tensor (batch, time, `*`).
context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns: Returns:
torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
""" """
self.extend_pe(x) self.extend_pe(x, context)
x_size_1 = x.size(1) + context
pos_emb = self.pe[ pos_emb = self.pe[
:, :,
self.pe.size(1) // 2 self.pe.size(1) // 2
- x.size(1) - x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203 + 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1), + x_size_1,
] ]
return self.dropout(x), self.dropout(pos_emb) return self.dropout(x), self.dropout(pos_emb)
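With a nonzero context the positional encoding has to cover relative positions out to ±(time + context - 1), so the slice returned by forward has length 2*(time + context) - 1 rather than 2*time - 1. A quick shape check, assuming the RelPositionalEncoding class above is importable from conformer.py:

import torch

pos = RelPositionalEncoding(256, 0.1)   # d_model=256, dropout=0.1, as in the constructor above

x = torch.randn(1, 16, 256)             # one 16-frame chunk: (batch, time, d_model)
_, pos_emb = pos(x, context=64)         # 64 frames of cached left context
assert pos_emb.shape == (1, 2 * (16 + 64) - 1, 256)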
@ -466,6 +696,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, need_weights: bool = True,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -479,6 +710,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
- Inputs: - Inputs:
@ -524,14 +758,18 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask, key_padding_mask=key_padding_mask,
need_weights=need_weights, need_weights=need_weights,
attn_mask=attn_mask, attn_mask=attn_mask,
left_context=left_context,
) )
def rel_shift(self, x: Tensor) -> Tensor: def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding. """Compute relative positional encoding.
Args: Args:
x: Input tensor (batch, head, time1, 2*time1-1). x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector. time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns: Returns:
Tensor: tensor of shape (batch, head, time1, time2) Tensor: tensor of shape (batch, head, time1, time2)
@ -539,14 +777,17 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query). the key, while time1 is for the query).
""" """
(batch_size, num_heads, time1, n) = x.shape (batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + left_context
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
# Note: TorchScript requires explicit arg for stride() # Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0) batch_stride = x.stride(0)
head_stride = x.stride(1) head_stride = x.stride(1)
time1_stride = x.stride(2) time1_stride = x.stride(2)
n_stride = x.stride(3) n_stride = x.stride(3)
return x.as_strided( return x.as_strided(
(batch_size, num_heads, time1, time1), (batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride), (batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1), storage_offset=n_stride * (time1 - 1),
) )
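rel_shift converts scores indexed by relative position into scores indexed by key position; with a left context the key axis grows from time1 to time2 = time1 + left_context, which is why the strided view now has time2 columns and the input must provide 2*time2 - 1 relative positions. A naive reference version, useful only for sanity-checking the as_strided output above (not the code used at runtime):

import torch

def rel_shift_reference(x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
    """x: (batch, head, time1, 2*time2-1) with time2 = time1 + left_context.
    Returns (batch, head, time1, time2) with out[..., i, j] = x[..., i, time1 - 1 - i + j],
    i.e. the score of key j (which may lie in the cached left context) for query i."""
    batch, head, time1, n = x.shape
    time2 = time1 + left_context
    assert n == 2 * time2 - 1
    out = torch.empty(batch, head, time1, time2, dtype=x.dtype, device=x.device)
    for i in range(time1):
        start = time1 - 1 - i
        out[:, :, i, :] = x[:, :, i, start : start + time2]
    return out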
@ -568,6 +809,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, need_weights: bool = True,
attn_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]: ) -> Tuple[Tensor, Optional[Tensor]]:
r""" r"""
Args: Args:
@ -585,6 +827,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
Inputs: Inputs:
@ -748,7 +993,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0) pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1 assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) # (batch, 2*time-1, head, d_k) --> (batch, head, d_k, 2*time-1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose( q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2 1, 2
@ -768,9 +1014,9 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1) q_with_bias_v, p
) # (batch, head, time1, 2*time1-1) ) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd) matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = ( attn_output_weights = (
matrix_ac + matrix_bd matrix_ac + matrix_bd
@ -805,6 +1051,24 @@ class RelPositionMultiheadAttention(nn.Module):
) )
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`. In that case, the
# whole row of `attn_output_weights` will be `-inf`
# (i.e. `nan` after softmax), so we fill `0.0` at the masked
# positions to avoid an invalid loss value below.
if (attn_mask is not None and attn_mask.dtype == torch.bool
and key_padding_mask is not None):
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = nn.functional.dropout( attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training attn_output_weights, p=dropout_p, training=training
) )
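The zero-filling above guards against a real failure mode: if every key position of some query row is disallowed, that row of attn_output_weights is all -inf going into the softmax, and softmax over an all -inf row produces NaN, which then propagates into the loss. A two-line illustration:

import torch

row = torch.full((1, 4), float("-inf"))   # a query that may attend to nothing
print(torch.softmax(row, dim=-1))          # tensor([[nan, nan, nan, nan]])
# zero-filling such rows afterwards keeps the weights finite; the masked
# positions then simply contribute nothing to the attention output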
@ -838,16 +1102,21 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers. channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers. kernel_size (int): Kernel size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True). bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution.
""" """
def __init__( def __init__(
self, channels: int, kernel_size: int, bias: bool = True self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False
) -> None: ) -> None:
"""Construct an ConvolutionModule object.""" """Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__() super(ConvolutionModule, self).__init__()
# kernel_size should be an odd number for 'SAME' padding # kernel_size should be an odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0 assert (kernel_size - 1) % 2 == 0
self.causal = causal
self.pointwise_conv1 = ScaledConv1d( self.pointwise_conv1 = ScaledConv1d(
channels, channels,
@ -875,12 +1144,17 @@ class ConvolutionModule(nn.Module):
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
) )
self.lorder = kernel_size - 1
padding = (kernel_size - 1) // 2
if self.causal:
padding = 0
self.depthwise_conv = ScaledConv1d( self.depthwise_conv = ScaledConv1d(
channels, channels,
channels, channels,
kernel_size, kernel_size,
stride=1, stride=1,
padding=(kernel_size - 1) // 2, padding=padding,
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
@ -921,6 +1195,10 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time) x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv # 1D Depthwise Conv
if self.causal and self.lorder > 0:
# Make depthwise_conv causal by
# manually padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
x = self.deriv_balancer2(x) x = self.deriv_balancer2(x)
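With causal=True the depthwise convolution drops the symmetric 'SAME' padding and instead pads lorder = kernel_size - 1 zeros on the left only, so each output frame depends only on current and past inputs while the sequence length is preserved. A small shape check using a plain nn.Conv1d in place of ScaledConv1d (an assumption for illustration only):

import torch
import torch.nn as nn

kernel_size, channels, T = 31, 8, 100
x = torch.randn(1, channels, T)                      # (batch, channels, time)

conv = nn.Conv1d(channels, channels, kernel_size, groups=channels, padding=0)
lorder = kernel_size - 1
y = conv(nn.functional.pad(x, (lorder, 0)))          # pad on the left only
assert y.shape == x.shape                            # length preserved
# changing a future input frame x[..., t+1:] cannot affect y[..., :t+1]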

View File

@ -53,6 +53,18 @@ Usage:
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
--max-states 8 --max-states 8
(5) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--right-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
""" """
@ -85,6 +97,7 @@ from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -190,6 +203,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
@ -198,6 +212,70 @@ def get_parser():
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
return parser return parser
@ -246,9 +324,19 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( if params.simulate_streaming:
x=feature, x_lens=feature_lens encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
) x=feature,
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -461,6 +549,10 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -490,6 +582,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -77,6 +78,8 @@ class Transducer(nn.Module):
am_scale: float = 0.0, am_scale: float = 0.0,
lm_scale: float = 0.0, lm_scale: float = 0.0,
warmup: float = 1.0, warmup: float = 1.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -154,10 +157,31 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale, lm_only_scale=lm_scale,
am_only_scale=am_scale, am_only_scale=am_scale,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(
T0, device=px_grad.device
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range] # ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges( ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad, px_grad=px_grad,
@ -186,8 +210,9 @@ class Transducer(nn.Module):
symbols=y_padded, symbols=y_padded,
ranges=ranges, ranges=ranges,
termination_symbol=blank_id, termination_symbol=blank_id,
delay_penalty=delay_penalty,
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
) )
return (simple_loss, pruned_loss) return (simple_loss, pruned_loss, sym_delay)

View File

@ -40,6 +40,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
# train a streaming model
./pruned_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
""" """
@ -263,6 +276,59 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -349,6 +415,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead, nhead=params.nhead,
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
) )
return encoder return encoder
@ -541,7 +611,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model( simple_loss, pruned_loss, sym_delay = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -549,6 +619,8 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
warmup=warmup, warmup=warmup,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
) )
# after the main warmup step, we keep pruned_loss_scale small # after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid # for the same amount of time (model_warm_step), to avoid
@ -577,6 +649,9 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()
return loss, info return loss, info
@ -806,6 +881,15 @@ def run(rank, world_size, args):
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")

View File

@ -18,7 +18,7 @@
import copy import copy
import math import math
import warnings import warnings
from typing import List, Optional, Tuple from typing import Optional, Tuple
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@ -27,70 +27,6 @@ from transformer import Transformer
from icefall.utils import make_pad_mask, subsequent_chunk_mask from icefall.utils import make_pad_mask, subsequent_chunk_mask
class DecodeStates(object):
def __init__(self,
layers: int,
left_context: int,
dim: int,
init: bool = True,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device('cpu')):
self.layers = layers
self.left_context = left_context
self.dim = dim
self.dtype = dtype
self.device = device
if init:
# shape (layer, T, dim)
self.attn_cache = torch.zeros((layers, left_context, dim),
dtype=dtype,
device=device)
self.conv_cache = torch.zeros((layers, left_context, dim),
dtype=dtype,
device=device)
self.offset = torch.tensor([0], dtype=dtype, device=device)
@staticmethod
def stack(states: List['DecodeStates']) -> 'DecodeStates':
assert len(states) >= 1
obj = DecodeStates(layers=states[0].layers,
left_context=states[0].left_context,
dim=states[0].dim,
init=False,
dtype=states[0].dtype,
device=states[0].device)
attn_cache = []
conv_cache = []
offset = []
for i in range(len(states)):
attn_cache.append(states[i].attn_cache)
conv_cache.append(states[i].conv_cache)
offset.append(states[i].offset)
obj.attn_cache = torch.stack(attn_cache, dim=2)
obj.conv_cache = torch.stack(conv_cache, dim=2)
obj.offset = torch.stack(offset, dim=0)
return obj
@staticmethod
def unstack(states: 'DecodeStates') -> List['DecodeStates']:
results = []
attn_cache = torch.unbind(states.attn_cache, dim=2)
conv_cache = torch.unbind(states.conv_cache, dim=2)
offset = torch.unbind(states.offset, dim=0)
for i in range(states.attn_cache.size(2)):
obj = DecodeStates(layers=states.layers,
left_context=states.left_context,
dim=states.dim,
init=False,
dtype=states.dtype,
device=states.device)
obj.attn_cache = attn_cache[i]
obj.conv_cache = conv_cache[i]
obj.offset = offset[i]
results.append(obj)
return results
class Conformer(Transformer): class Conformer(Transformer):
""" """
Args: Args:
@ -119,7 +55,7 @@ class Conformer(Transformer):
size equals to or less than ``max_len * short_chunk_threshold``, the size equals to or less than ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size. chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True. This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context attention can see in chunks, the num_left_chunks (int): the left context (in chunks) attention can see, the
chunk size is decided by short_chunk_threshold and short_chunk_size. chunk size is decided by short_chunk_threshold and short_chunk_size.
A negative value means seeing the full left context. A negative value means seeing the full left context.
This also will be used only when dynamic_chunk_training is True. This also will be used only when dynamic_chunk_training is True.
@ -159,6 +95,8 @@ class Conformer(Transformer):
vgg_frontend=vgg_frontend, vgg_frontend=vgg_frontend,
) )
self.encoder_layers = num_encoder_layers
self.d_model = d_model
self.dynamic_chunk_training = dynamic_chunk_training self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size self.short_chunk_size = short_chunk_size
@ -231,7 +169,7 @@ class Conformer(Transformer):
num_left_chunks=self.num_left_chunks, device=x.device num_left_chunks=self.num_left_chunks, device=x.device
) )
x = self.encoder( x, _ = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, N, C) ) # (T, N, C)
@ -248,11 +186,11 @@ class Conformer(Transformer):
self, self,
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
decode_states: Optional[DecodeStates] = None, states: Optional[torch.Tensor] = None,
chunk_size: int = 16, chunk_size: int = 16,
left_context: int = 64, left_context: int = 64,
simulate_streaming: bool = False, simulate_streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Args: Args:
x: x:
@ -260,9 +198,10 @@ class Conformer(Transformer):
x_lens: x_lens:
A tensor of shape (batch_size,) containing the number of frames in A tensor of shape (batch_size,) containing the number of frames in
`x` before padding. `x` before padding.
decode_states: states:
The decode states for previous frames which contains the cached data The decode states for previous frames which contains the cached data.
and the offset of current chunk in the whole sequence. It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
chunk_size: chunk_size:
The chunk size for decoding, this will be used to simulate streaming The chunk size for decoding, this will be used to simulate streaming
decoding using masking. decoding using masking.
@ -289,13 +228,15 @@ class Conformer(Transformer):
if not simulate_streaming: if not simulate_streaming:
assert ( assert (
decode_states is not None states is not None
), "Require cache when sending data in streaming mode" ), "Require cache when sending data in streaming mode"
assert ( assert (
left_context == decode_states.left_context states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
), f"""The given left_context must equal to the left_context in ), f"""The shape of states MUST be equal to
`decode_states`, need {decode_states.left_context} given (2, encoder_layers, left_context, batch, d_model) which is
{left_context}.""" {(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
given {states.shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context) src_key_padding_mask = make_pad_mask(lengths + left_context)
@ -303,18 +244,16 @@ class Conformer(Transformer):
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x = self.encoder( x, states = self.encoder(
embed, embed,
pos_enc, pos_enc,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
attn_cache=decode_states.attn_cache, states=states,
conv_cache=decode_states.conv_cache, left_context=left_context,
left_context=decode_states.left_context,
) # (T, B, F) ) # (T, B, F)
decode_states.offset += embed.size(0)
else: else:
assert decode_states is None assert states is None
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x) x = self.encoder_embed(x)
@ -322,8 +261,11 @@ class Conformer(Transformer):
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask( mask = ~subsequent_chunk_mask(
size=x.size(0), size=x.size(0),
@ -331,7 +273,7 @@ class Conformer(Transformer):
num_left_chunks=num_left_chunks, num_left_chunks=num_left_chunks,
device=x.device device=x.device
) )
x = self.encoder( x, _ = self.encoder(
x, x,
pos_emb, pos_emb,
mask=mask, mask=mask,
@ -344,7 +286,7 @@ class Conformer(Transformer):
logits = self.encoder_output_layer(x) logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths, decode_states return logits, lengths, states
class ConformerEncoderLayer(nn.Module): class ConformerEncoderLayer(nn.Module):
@ -425,10 +367,9 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Tensor, pos_emb: Tensor,
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
attn_cache: Optional[Tensor] = None, states: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
left_context: int = 0, left_context: int = 0,
) -> Tensor: ) -> Tuple[Tensor, Tensor]:
""" """
Pass the input through the encoder layer. Pass the input through the encoder layer.
@ -437,9 +378,10 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional). src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
attn_cache: attention cache for previous frames. states: The decode states for previous frames which contains the cached data.
conv_cache: convolution cache for previous frames. It has a shape of (2, left_context, batch, attention_dim),
left_context: left context in frames used during streaming decoding. states[0,...] is the attn_cache, states[1,...] is the conv_cache.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
Shape: Shape:
@ -467,11 +409,11 @@ class ConformerEncoderLayer(nn.Module):
key = src key = src
val = src val = src
if not self.training and attn_cache is not None: if not self.training and states is not None:
# src: [chunk_size, N, F] e.g. [8, 41, 512] # src: [chunk_size, N, F] e.g. [8, 41, 512]
key = torch.cat([attn_cache, src], dim=0) key = torch.cat([states[0, ...], src], dim=0)
val = key val = key
attn_cache = key states[0, ...] = key[-left_context:, ...]
else: else:
assert left_context == 0 assert left_context == 0
@ -493,9 +435,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
src = self.norm_conv(src) src = self.norm_conv(src)
if not self.training and conv_cache is not None: if not self.training and states is not None:
src = torch.cat([conv_cache, src], dim=0) src = torch.cat([states[1, ...], src], dim=0)
conv_cache = src states[1, ...] = src[-left_context:, ...]
src = self.conv_module(src) src = self.conv_module(src)
src = src[-residual.size(0) :, :, :] # noqa: E203 src = src[-residual.size(0) :, :, :] # noqa: E203
@ -515,7 +457,7 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
src = self.norm_final(src) src = self.norm_final(src)
return src, attn_cache, conv_cache return src, states
class ConformerEncoder(nn.Module): class ConformerEncoder(nn.Module):
@ -546,10 +488,9 @@ class ConformerEncoder(nn.Module):
pos_emb: Tensor, pos_emb: Tensor,
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
attn_cache: Optional[Tensor] = None, states: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
left_context: int = 0, left_context: int = 0,
) -> Tensor: ) -> Tuple[Tensor, Tensor]:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
Args: Args:
@ -557,9 +498,10 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional). mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
attn_cache: attention cache for previous frames. states: The decode states for previous frames which contains the cached data.
conv_cache: convolution cache for previous frames. It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
left_context: left context in frames used during streaming decoding. states[0,...] is the attn_cache, states[1,...] is the conv_cache.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
Shape: Shape:
@ -576,26 +518,23 @@ class ConformerEncoder(nn.Module):
if self.training: if self.training:
assert left_context == 0 assert left_context == 0
assert attn_cache is None assert states is None
assert conv_cache is None
else: else:
assert left_context >= 0 assert left_context >= 0
for layer_index, mod in enumerate(self.layers): for layer_index, mod in enumerate(self.layers):
output, a_cache, c_cache = mod( output, cache = mod(
output, output,
pos_emb, pos_emb,
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
attn_cache=None if attn_cache is None else attn_cache[layer_index], states=None if states is None else states[:, layer_index, ...],
conv_cache=None if conv_cache is None else conv_cache[layer_index],
left_context=left_context, left_context=left_context,
) )
if attn_cache is not None and conv_cache is not None: if states is not None:
attn_cache[layer_index, ...] = a_cache[-left_context:, ...] states[:, layer_index, ...] = cache
conv_cache[layer_index, ...] = c_cache[-left_context:, ...]
return output return output, states
class RelPositionalEncoding(torch.nn.Module): class RelPositionalEncoding(torch.nn.Module):
@ -667,7 +606,7 @@ class RelPositionalEncoding(torch.nn.Module):
Args: Args:
x (torch.Tensor): Input tensor (batch, time, `*`). x (torch.Tensor): Input tensor (batch, time, `*`).
context (int): left context in frames used during streaming decoding. context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
Returns: Returns:
@ -762,7 +701,7 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context in frames used during streaming decoding. left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
@ -819,7 +758,7 @@ class RelPositionMultiheadAttention(nn.Module):
Args: Args:
x: Input tensor (batch, head, time1, 2*time1-1). x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector. time1 means the length of query vector.
left_context (int): left context in frames used during streaming decoding. left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.
@ -879,7 +818,7 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context in frames used during streaming decoding. left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances, this is used only in real streaming decoding, in other circumstances,
it MUST be 0. it MUST be 0.

View File

@ -535,8 +535,11 @@ class MetricsTracker(collections.defaultdict):
ans = [] ans = []
for k, v in self.items(): for k, v in self.items():
if k != "frames": if k != "frames":
norm_value = float(v) / num_frames if k != "sym_delay":
ans.append((k, norm_value)) norm_value = float(v) / num_frames
ans.append((k, norm_value))
else:
ans.append((k, float(v)))
return ans return ans
def reduce(self, device): def reduce(self, device):