mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
support streaming on pruned_transducer_stateless2; add delay penalty; fixes for decode states
This commit is contained in:
parent
118b09463d
commit
7cc697c03a
@ -66,6 +66,8 @@ class Transducer(nn.Module):
|
|||||||
prune_range: int = 5,
|
prune_range: int = 5,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
|
delay_penalty: float = 0.0,
|
||||||
|
return_sym_delay: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -136,10 +138,31 @@ class Transducer(nn.Module):
|
|||||||
lm_only_scale=lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
am_only_scale=am_scale,
|
am_only_scale=am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
delay_penalty=delay_penalty,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sym_delay = None
|
||||||
|
if return_sym_delay:
|
||||||
|
B, S, T0 = px_grad.shape
|
||||||
|
T = T0 - 1
|
||||||
|
if boundary is None:
|
||||||
|
offset = torch.tensor(
|
||||||
|
(T - 1) / 2,
|
||||||
|
dtype=px_grad.dtype,
|
||||||
|
device=px_grad.device,
|
||||||
|
).expand(B, 1, 1)
|
||||||
|
total_syms = S * B
|
||||||
|
else:
|
||||||
|
offset = (boundary[:, 3] - 1) / 2
|
||||||
|
total_syms = torch.sum(boundary[:, 2])
|
||||||
|
offset = torch.arange(
|
||||||
|
T0, device=px_grad.device
|
||||||
|
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
|
||||||
|
sym_delay = px_grad * offset
|
||||||
|
sym_delay = torch.sum(sym_delay) / total_syms
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad=px_grad,
|
px_grad=px_grad,
|
||||||
@ -163,7 +186,8 @@ class Transducer(nn.Module):
|
|||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
delay_penalty=delay_penalty,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss, sym_delay)
|
||||||
|
@ -269,6 +269,25 @@ def get_parser():
|
|||||||
help="How many left context can be seen in chunks when calculating attention.",
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--delay-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""A constant value to penalize symbol delay, this may be
|
||||||
|
needed when training with time masking, to avoid the time masking
|
||||||
|
encouraging the network to delay symbols.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-sym-delay",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to return `sym_delay` during training, this is a stat
|
||||||
|
to measure symbols emission delay, especially for time masking training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -536,14 +555,17 @@ def compute_loss(
|
|||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
|
sym_delay = None
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(
|
simple_loss, pruned_loss, sym_delay = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
prune_range=params.prune_range,
|
prune_range=params.prune_range,
|
||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
|
delay_penalty=params.delay_penalty,
|
||||||
|
return_sym_delay=params.return_sym_delay,
|
||||||
)
|
)
|
||||||
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
||||||
|
|
||||||
@ -561,6 +583,9 @@ def compute_loss(
|
|||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
|
if sym_delay is not None:
|
||||||
|
info["sym_delay"] = sym_delay.detatch().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from scaling import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
class Conformer(EncoderInterface):
|
class Conformer(EncoderInterface):
|
||||||
@ -48,6 +48,26 @@ class Conformer(EncoderInterface):
|
|||||||
layer_dropout (float): layer-dropout rate.
|
layer_dropout (float): layer-dropout rate.
|
||||||
cnn_module_kernel (int): Kernel size of convolution module
|
cnn_module_kernel (int): Kernel size of convolution module
|
||||||
vgg_frontend (bool): whether to use vgg frontend.
|
vgg_frontend (bool): whether to use vgg frontend.
|
||||||
|
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
|
||||||
|
you want to train a streaming model, this is expected to be True.
|
||||||
|
When setting True, it will use a masking strategy to make the attention
|
||||||
|
see only limited left and right context.
|
||||||
|
short_chunk_threshold (float): a threshold to determinize the chunk size
|
||||||
|
to be used in masking training, if the randomly generated chunk size
|
||||||
|
is greater than ``max_len * short_chunk_threshold`` (max_len is the
|
||||||
|
max sequence length of current batch) then it will use
|
||||||
|
full context in training (i.e. with chunk size equals to max_len).
|
||||||
|
This will be used only when dynamic_chunk_training is True.
|
||||||
|
short_chunk_size (int): see docs above, if the randomly generated chunk
|
||||||
|
size equals to or less than ``max_len * short_chunk_threshold``, the
|
||||||
|
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
||||||
|
This also will be used only when dynamic_chunk_training is True.
|
||||||
|
num_left_chunks (int): the left context (in chunks) attention can see, the
|
||||||
|
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
||||||
|
A minus value means seeing full left context.
|
||||||
|
This also will be used only when dynamic_chunk_training is True.
|
||||||
|
causal (bool): Whether to use causal convolution in conformer encoder
|
||||||
|
layer. This MUST be True when using dynamic_chunk_training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -61,6 +81,11 @@ class Conformer(EncoderInterface):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
|
dynamic_chunk_training: bool = False,
|
||||||
|
short_chunk_threshold: float = 0.75,
|
||||||
|
short_chunk_size: int = 25,
|
||||||
|
num_left_chunks: int = -1,
|
||||||
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__()
|
super(Conformer, self).__init__()
|
||||||
|
|
||||||
@ -76,6 +101,14 @@ class Conformer(EncoderInterface):
|
|||||||
# (2) embedding: num_features -> d_model
|
# (2) embedding: num_features -> d_model
|
||||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||||
|
|
||||||
|
self.encoder_layers = num_encoder_layers
|
||||||
|
self.d_model = d_model
|
||||||
|
self.dynamic_chunk_training = dynamic_chunk_training
|
||||||
|
self.short_chunk_threshold = short_chunk_threshold
|
||||||
|
self.short_chunk_size = short_chunk_size
|
||||||
|
self.num_left_chunks = num_left_chunks
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
|
||||||
encoder_layer = ConformerEncoderLayer(
|
encoder_layer = ConformerEncoderLayer(
|
||||||
@ -85,6 +118,7 @@ class Conformer(EncoderInterface):
|
|||||||
dropout,
|
dropout,
|
||||||
layer_dropout,
|
layer_dropout,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel,
|
||||||
|
causal,
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
|
|
||||||
@ -117,10 +151,31 @@ class Conformer(EncoderInterface):
|
|||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
|
||||||
|
|
||||||
x = self.encoder(
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
mask = None
|
||||||
|
|
||||||
|
if self.dynamic_chunk_training:
|
||||||
|
assert (
|
||||||
|
self.causal
|
||||||
|
), "Causal convolution is required for streaming conformer."
|
||||||
|
max_len = x.size(0)
|
||||||
|
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||||
|
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||||
|
chunk_size = max_len
|
||||||
|
else:
|
||||||
|
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||||
|
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0), chunk_size=chunk_size,
|
||||||
|
num_left_chunks=self.num_left_chunks, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
x, _ = self.encoder(
|
||||||
|
x, pos_emb,
|
||||||
|
mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
) # (T, N, C)
|
) # (T, N, C)
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
@ -128,6 +183,116 @@ class Conformer(EncoderInterface):
|
|||||||
return x, lengths
|
return x, lengths
|
||||||
|
|
||||||
|
|
||||||
|
def streaming_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
states: Optional[Tensor] = None,
|
||||||
|
chunk_size: int = 16,
|
||||||
|
left_context: int = 64,
|
||||||
|
simulate_streaming: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
|
x_lens:
|
||||||
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
|
`x` before padding.
|
||||||
|
warmup:
|
||||||
|
A floating point value that gradually increases from 0 throughout
|
||||||
|
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||||
|
to turn modules on sequentially.
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
|
||||||
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
|
chunk_size:
|
||||||
|
The chunk size for decoding, this will be used to simulate streaming
|
||||||
|
decoding using masking.
|
||||||
|
left_context:
|
||||||
|
How many old frames the attention can see in current chunk, it MUST
|
||||||
|
be equal to left_context in decode_states.
|
||||||
|
simulate_streaming:
|
||||||
|
If setting True, it will use a masking strategy to simulate streaming
|
||||||
|
fashion (i.e. every chunk data only see limited left context and
|
||||||
|
right context). The whole sequence is supposed to be send at a time
|
||||||
|
When using simulate_streaming.
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing 2 tensors:
|
||||||
|
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||||
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
|
of frames in `logits` before padding.
|
||||||
|
- decode_states, the updated DecodeStates including the information
|
||||||
|
of current chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# x: [N, T, C]
|
||||||
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
|
|
||||||
|
if not simulate_streaming:
|
||||||
|
assert (
|
||||||
|
decode_states is not None
|
||||||
|
), "Require cache when sending data in streaming mode"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
|
||||||
|
), f"""The shape of states MUST be equal to
|
||||||
|
(2, encoder_layers, left_context, batch, d_model) which is
|
||||||
|
{(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
|
||||||
|
given {states.shape}."""
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||||
|
|
||||||
|
embed = self.encoder_embed(x)
|
||||||
|
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||||
|
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
|
x, states = self.encoder(
|
||||||
|
embed,
|
||||||
|
pos_enc,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
|
states=states,
|
||||||
|
left_context=left_context,
|
||||||
|
) # (T, B, F)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert states is None
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
num_left_chunks = -1
|
||||||
|
if left_context >= 0:
|
||||||
|
assert left_context % chunk_size == 0
|
||||||
|
num_left_chunks = left_context // chunk_size
|
||||||
|
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0),
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
num_left_chunks=num_left_chunks,
|
||||||
|
device=x.device
|
||||||
|
)
|
||||||
|
x, _ = self.encoder(
|
||||||
|
x,
|
||||||
|
pos_emb,
|
||||||
|
mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
warmup=warmup,
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
return x, lengths, states
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoderLayer(nn.Module):
|
class ConformerEncoderLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||||
@ -139,6 +304,8 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
dropout: the dropout value (default=0.1).
|
dropout: the dropout value (default=0.1).
|
||||||
cnn_module_kernel (int): Kernel size of convolution module.
|
cnn_module_kernel (int): Kernel size of convolution module.
|
||||||
|
causal (bool): Whether to use causal convolution in conformer encoder
|
||||||
|
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||||
@ -155,6 +322,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
|
|
||||||
@ -182,7 +350,11 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(
|
||||||
|
d_model,
|
||||||
|
cnn_module_kernel,
|
||||||
|
causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
self.norm_final = BasicNorm(d_model)
|
self.norm_final = BasicNorm(d_model)
|
||||||
|
|
||||||
@ -200,7 +372,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src_mask: Optional[Tensor] = None,
|
src_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tensor:
|
states: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Pass the input through the encoder layer.
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
@ -211,10 +385,17 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
bypass layers more frequently.
|
bypass layers more frequently.
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
|
||||||
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E)
|
pos_emb: (N, 2*S-1, E),for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||||
src_mask: (S, S).
|
src_mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
@ -236,19 +417,38 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
# macaron style feed forward module
|
# macaron style feed forward module
|
||||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||||
|
|
||||||
|
key = src
|
||||||
|
val = src
|
||||||
|
if not self.training and states is not None:
|
||||||
|
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
|
key = torch.cat([states[0, ...], src], dim=0)
|
||||||
|
val = key
|
||||||
|
states[0, ...] = key[-left_context, ...]
|
||||||
|
else:
|
||||||
|
assert left_context == 0
|
||||||
|
|
||||||
# multi-headed self-attention module
|
# multi-headed self-attention module
|
||||||
src_att = self.self_attn(
|
src_att = self.self_attn(
|
||||||
src,
|
src,
|
||||||
src,
|
key,
|
||||||
src,
|
val,
|
||||||
pos_emb=pos_emb,
|
pos_emb=pos_emb,
|
||||||
attn_mask=src_mask,
|
attn_mask=src_mask,
|
||||||
key_padding_mask=src_key_padding_mask,
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
left_context=left_context,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
src = src + self.dropout(src_att)
|
src = src + self.dropout(src_att)
|
||||||
|
|
||||||
# convolution module
|
# convolution module
|
||||||
src = src + self.dropout(self.conv_module(src))
|
if not self.training and states is not None:
|
||||||
|
src = torch.cat([states[1, ...], src], dim=0)
|
||||||
|
states[1, ...] = src[-left_context, ...]
|
||||||
|
|
||||||
|
conv = self.conv_module(src)
|
||||||
|
conv = conv[-src.size(0) :, :, :] # noqa: E203
|
||||||
|
|
||||||
|
src = src + self.dropout(conv)
|
||||||
|
|
||||||
# feed forward module
|
# feed forward module
|
||||||
src = src + self.dropout(self.feed_forward(src))
|
src = src + self.dropout(self.feed_forward(src))
|
||||||
@ -258,7 +458,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if alpha != 1.0:
|
if alpha != 1.0:
|
||||||
src = alpha * src + (1 - alpha) * src_orig
|
src = alpha * src + (1 - alpha) * src_orig
|
||||||
|
|
||||||
return src
|
return src, states
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoder(nn.Module):
|
class ConformerEncoder(nn.Module):
|
||||||
@ -290,7 +490,9 @@ class ConformerEncoder(nn.Module):
|
|||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
) -> Tensor:
|
states: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -298,10 +500,19 @@ class ConformerEncoder(nn.Module):
|
|||||||
pos_emb: Positional embedding tensor (required).
|
pos_emb: Positional embedding tensor (required).
|
||||||
mask: the mask for the src sequence (optional).
|
mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||||
|
bypass layers more frequently.
|
||||||
|
states:
|
||||||
|
The decode states for previous frames which contains the cached data.
|
||||||
|
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
|
||||||
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
src: (S, N, E).
|
src: (S, N, E).
|
||||||
pos_emb: (N, 2*S-1, E)
|
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||||
mask: (S, S).
|
mask: (S, S).
|
||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
@ -309,16 +520,26 @@ class ConformerEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
for i, mod in enumerate(self.layers):
|
if self.training:
|
||||||
output = mod(
|
assert left_context == 0
|
||||||
|
assert states is None
|
||||||
|
else:
|
||||||
|
assert left_context >= 0
|
||||||
|
|
||||||
|
for layer_index, mod in enumerate(self.layers):
|
||||||
|
output, cache = mod(
|
||||||
output,
|
output,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
src_mask=mask,
|
src_mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
|
states=None if states is None else states[:, layer_index, ...],
|
||||||
|
left_context=left_context,
|
||||||
)
|
)
|
||||||
|
if states is not None:
|
||||||
|
states[:, layer_index, ...] = cache
|
||||||
|
|
||||||
return output
|
return output, states
|
||||||
|
|
||||||
|
|
||||||
class RelPositionalEncoding(torch.nn.Module):
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
@ -344,12 +565,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor) -> None:
|
def extend_pe(self, x: Tensor, context: int = 0) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
|
x_size_1 = x.size(1) + context
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
x.device
|
x.device
|
||||||
@ -359,9 +581,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||||
# position of key vector. We use position relative positions when keys
|
# position of key vector. We use position relative positions when keys
|
||||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp(
|
||||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||||
* -(math.log(10000.0) / self.d_model)
|
* -(math.log(10000.0) / self.d_model)
|
||||||
@ -379,24 +601,32 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
context: int = 0
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
|
context (int): left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
self.extend_pe(x, context)
|
||||||
|
x_size_1 = x.size(1) + context
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x.size(1)
|
- x_size_1
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x.size(1),
|
+ x_size_1,
|
||||||
]
|
]
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
@ -466,6 +696,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -479,6 +710,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
need_weights: output attn_output_weights.
|
need_weights: output attn_output_weights.
|
||||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
- Inputs:
|
- Inputs:
|
||||||
@ -524,14 +758,18 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask=key_padding_mask,
|
key_padding_mask=key_padding_mask,
|
||||||
need_weights=need_weights,
|
need_weights=need_weights,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
left_context=left_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def rel_shift(self, x: Tensor) -> Tensor:
|
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
|
||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||||
time1 means the length of query vector.
|
time1 means the length of query vector.
|
||||||
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: tensor of shape (batch, head, time1, time2)
|
Tensor: tensor of shape (batch, head, time1, time2)
|
||||||
@ -539,14 +777,17 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
the key, while time1 is for the query).
|
the key, while time1 is for the query).
|
||||||
"""
|
"""
|
||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
assert n == 2 * time1 - 1
|
|
||||||
|
time2 = time1 + left_context
|
||||||
|
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
batch_stride = x.stride(0)
|
batch_stride = x.stride(0)
|
||||||
head_stride = x.stride(1)
|
head_stride = x.stride(1)
|
||||||
time1_stride = x.stride(2)
|
time1_stride = x.stride(2)
|
||||||
n_stride = x.stride(3)
|
n_stride = x.stride(3)
|
||||||
return x.as_strided(
|
return x.as_strided(
|
||||||
(batch_size, num_heads, time1, time1),
|
(batch_size, num_heads, time1, time2),
|
||||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
storage_offset=n_stride * (time1 - 1),
|
storage_offset=n_stride * (time1 - 1),
|
||||||
)
|
)
|
||||||
@ -568,6 +809,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
left_context: int = 0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -585,6 +827,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
need_weights: output attn_output_weights.
|
need_weights: output attn_output_weights.
|
||||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
|
it MUST be 0.
|
||||||
|
|
||||||
Shape:
|
Shape:
|
||||||
Inputs:
|
Inputs:
|
||||||
@ -748,7 +993,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
pos_emb_bsz = pos_emb.size(0)
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||||
|
p = p.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
@ -768,9 +1014,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
# compute matrix b and matrix d
|
# compute matrix b and matrix d
|
||||||
matrix_bd = torch.matmul(
|
matrix_bd = torch.matmul(
|
||||||
q_with_bias_v, p.transpose(-2, -1)
|
q_with_bias_v, p
|
||||||
) # (batch, head, time1, 2*time1-1)
|
) # (batch, head, time1, 2*time1-1)
|
||||||
matrix_bd = self.rel_shift(matrix_bd)
|
matrix_bd = self.rel_shift(matrix_bd, left_context)
|
||||||
|
|
||||||
attn_output_weights = (
|
attn_output_weights = (
|
||||||
matrix_ac + matrix_bd
|
matrix_ac + matrix_bd
|
||||||
@ -805,6 +1051,24 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||||
|
|
||||||
|
# If we are using dynamic_chunk_training and setting a limited
|
||||||
|
# num_left_chunks, the attention may only see the padding values which
|
||||||
|
# will also be masked out by `key_padding_mask`, at this circumstances,
|
||||||
|
# the whole column of `attn_output_weights` will be `-inf`
|
||||||
|
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
||||||
|
# positions to avoid invalid loss value below.
|
||||||
|
if attn_mask is not None and attn_mask.dtype == torch.bool and \
|
||||||
|
key_padding_mask is not None:
|
||||||
|
combined_mask = attn_mask.unsqueeze(
|
||||||
|
0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz, num_heads, tgt_len, src_len)
|
||||||
|
attn_output_weights = attn_output_weights.masked_fill(
|
||||||
|
combined_mask, 0.0)
|
||||||
|
attn_output_weights = attn_output_weights.view(
|
||||||
|
bsz * num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
attn_output_weights = nn.functional.dropout(
|
attn_output_weights = nn.functional.dropout(
|
||||||
attn_output_weights, p=dropout_p, training=training
|
attn_output_weights, p=dropout_p, training=training
|
||||||
)
|
)
|
||||||
@ -838,16 +1102,21 @@ class ConvolutionModule(nn.Module):
|
|||||||
channels (int): The number of channels of conv layers.
|
channels (int): The number of channels of conv layers.
|
||||||
kernel_size (int): Kernerl size of conv layers.
|
kernel_size (int): Kernerl size of conv layers.
|
||||||
bias (bool): Whether to use bias in conv layers (default=True).
|
bias (bool): Whether to use bias in conv layers (default=True).
|
||||||
|
causal (bool): Whether to use causal convolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, channels: int, kernel_size: int, bias: bool = True
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
bias: bool = True,
|
||||||
|
causal: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an ConvolutionModule object."""
|
"""Construct an ConvolutionModule object."""
|
||||||
super(ConvolutionModule, self).__init__()
|
super(ConvolutionModule, self).__init__()
|
||||||
# kernerl_size should be a odd number for 'SAME' padding
|
# kernerl_size should be a odd number for 'SAME' padding
|
||||||
assert (kernel_size - 1) % 2 == 0
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
self.pointwise_conv1 = ScaledConv1d(
|
self.pointwise_conv1 = ScaledConv1d(
|
||||||
channels,
|
channels,
|
||||||
@ -875,12 +1144,17 @@ class ConvolutionModule(nn.Module):
|
|||||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.lorder = kernel_size - 1
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
if self.causal:
|
||||||
|
padding = 0
|
||||||
|
|
||||||
self.depthwise_conv = ScaledConv1d(
|
self.depthwise_conv = ScaledConv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=(kernel_size - 1) // 2,
|
padding=padding,
|
||||||
groups=channels,
|
groups=channels,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
@ -921,6 +1195,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
|
if self.causal and self.lorder > 0:
|
||||||
|
# Make depthwise_conv causal by
|
||||||
|
# manualy padding self.lorder zeros to the left
|
||||||
|
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
|
|
||||||
x = self.deriv_balancer2(x)
|
x = self.deriv_balancer2(x)
|
||||||
|
@ -53,6 +53,18 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) decode in streaming mode (take greedy search as an example)
|
||||||
|
./pruned_transducer_stateless2/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--simulate-streaming 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--right-chunk-size 16 \
|
||||||
|
--left-context 64 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -85,6 +97,7 @@ from icefall.utils import (
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -190,6 +203,7 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
@ -198,6 +212,70 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""How many left context can be seen in chunks when calculating attention.
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||||
|
test a streaming model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--right-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -246,9 +324,19 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
if params.simulate_streaming:
|
||||||
x=feature, x_lens=feature_lens
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
)
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
chunk_size=params.right_chunk_size,
|
||||||
|
left_context=params.left_context,
|
||||||
|
simulate_streaming=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
x=feature, x_lens=feature_lens
|
||||||
|
)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
@ -461,6 +549,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -490,6 +582,11 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "Decoding in streaming requires causal convolution"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -77,6 +78,8 @@ class Transducer(nn.Module):
|
|||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
|
delay_penalty: float = 0.0,
|
||||||
|
return_sym_delay: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -154,10 +157,31 @@ class Transducer(nn.Module):
|
|||||||
lm_only_scale=lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
am_only_scale=am_scale,
|
am_only_scale=am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
|
delay_penalty=delay_penalty,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sym_delay = None
|
||||||
|
if return_sym_delay:
|
||||||
|
B, S, T0 = px_grad.shape
|
||||||
|
T = T0 - 1
|
||||||
|
if boundary is None:
|
||||||
|
offset = torch.tensor(
|
||||||
|
(T - 1) / 2,
|
||||||
|
dtype=px_grad.dtype,
|
||||||
|
device=px_grad.device,
|
||||||
|
).expand(B, 1, 1)
|
||||||
|
total_syms = S * B
|
||||||
|
else:
|
||||||
|
offset = (boundary[:, 3] - 1) / 2
|
||||||
|
total_syms = torch.sum(boundary[:, 2])
|
||||||
|
offset = torch.arange(
|
||||||
|
T0, device=px_grad.device
|
||||||
|
).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
|
||||||
|
sym_delay = px_grad * offset
|
||||||
|
sym_delay = torch.sum(sym_delay) / total_syms
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad=px_grad,
|
px_grad=px_grad,
|
||||||
@ -186,8 +210,9 @@ class Transducer(nn.Module):
|
|||||||
symbols=y_padded,
|
symbols=y_padded,
|
||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
|
delay_penalty=delay_penalty,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss, sym_delay)
|
||||||
|
@ -40,6 +40,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
|
# train a streaming model
|
||||||
|
./pruned_transducer_stateless2/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--exp-dir pruned_transducer_stateless/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--dynamic-chunk-training 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--short-chunk-size 25 \
|
||||||
|
--num-left-chunks 4 \
|
||||||
|
--max-duration 300
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -263,6 +276,59 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--delay-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""A constant value to penalize symbol delay, this may be
|
||||||
|
needed when training with time masking, to avoid the time masking
|
||||||
|
encouraging the network to delay symbols.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-sym-delay",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to return `sym_delay` during training, this is a stat
|
||||||
|
to measure symbols emission delay, especially for time masking training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -349,6 +415,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||||
|
short_chunk_size=params.short_chunk_size,
|
||||||
|
num_left_chunks=params.num_left_chunks,
|
||||||
|
causal=params.causal_convolution,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -541,7 +611,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(
|
simple_loss, pruned_loss, sym_delay = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -549,6 +619,8 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
|
delay_penalty=params.delay_penalty,
|
||||||
|
return_sym_delay=params.return_sym_delay,
|
||||||
)
|
)
|
||||||
# after the main warmup step, we keep pruned_loss_scale small
|
# after the main warmup step, we keep pruned_loss_scale small
|
||||||
# for the same amount of time (model_warm_step), to avoid
|
# for the same amount of time (model_warm_step), to avoid
|
||||||
@ -577,6 +649,9 @@ def compute_loss(
|
|||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
|
if params.return_sym_delay:
|
||||||
|
info["sym_delay"] = sym_delay.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -806,6 +881,15 @@ def run(rank, world_size, args):
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.dynamic_chunk_training:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "dynamic_chunk_training requires causal convolution"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
params.delay_penalty == 0.0
|
||||||
|
), "delay_penalty is intended for dynamic_chunk_training"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@ -27,70 +27,6 @@ from transformer import Transformer
|
|||||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
class DecodeStates(object):
|
|
||||||
def __init__(self,
|
|
||||||
layers: int,
|
|
||||||
left_context: int,
|
|
||||||
dim: int,
|
|
||||||
init: bool = True,
|
|
||||||
dtype: torch.dtype = torch.float32,
|
|
||||||
device: torch.device = torch.device('cpu')):
|
|
||||||
self.layers = layers
|
|
||||||
self.left_context = left_context
|
|
||||||
self.dim = dim
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
if init:
|
|
||||||
# shape (layer, T, dim)
|
|
||||||
self.attn_cache = torch.zeros((layers, left_context, dim),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
self.conv_cache = torch.zeros((layers, left_context, dim),
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
self.offset = torch.tensor([0], dtype=dtype, device=device)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def stack(states: List['DecodeStates']) -> 'DecodeStates':
|
|
||||||
assert len(states) >= 1
|
|
||||||
obj = DecodeStates(layers=states[0].layers,
|
|
||||||
left_context=states[0].left_context,
|
|
||||||
dim=states[0].dim,
|
|
||||||
init=False,
|
|
||||||
dtype=states[0].dtype,
|
|
||||||
device=states[0].device)
|
|
||||||
attn_cache = []
|
|
||||||
conv_cache = []
|
|
||||||
offset = []
|
|
||||||
for i in range(len(states)):
|
|
||||||
attn_cache.append(states[i].attn_cache)
|
|
||||||
conv_cache.append(states[i].conv_cache)
|
|
||||||
offset.append(states[i].offset)
|
|
||||||
obj.attn_cache = torch.stack(attn_cache, dim=2)
|
|
||||||
obj.conv_cache = torch.stack(conv_cache, dim=2)
|
|
||||||
obj.offset = torch.stack(offset, dim=0)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def unstack(states: 'DecodeStates') -> List['DecodeStates']:
|
|
||||||
results = []
|
|
||||||
attn_cache = torch.unbind(states.attn_cache, dim=2)
|
|
||||||
conv_cache = torch.unbind(states.conv_cache, dim=2)
|
|
||||||
offset = torch.unbind(states.offset, dim=0)
|
|
||||||
for i in range(states.attn_cache.size(2)):
|
|
||||||
obj = DecodeStates(layers=states.layers,
|
|
||||||
left_context=states.left_context,
|
|
||||||
dim=states.dim,
|
|
||||||
init=False,
|
|
||||||
dtype=states.dtype,
|
|
||||||
device=states.device)
|
|
||||||
obj.attn_cache = attn_cache[i]
|
|
||||||
obj.conv_cache = conv_cache[i]
|
|
||||||
obj.offset = offset[i]
|
|
||||||
results.append(obj)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -119,7 +55,7 @@ class Conformer(Transformer):
|
|||||||
size equals to or less than ``max_len * short_chunk_threshold``, the
|
size equals to or less than ``max_len * short_chunk_threshold``, the
|
||||||
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
||||||
This also will be used only when dynamic_chunk_training is True.
|
This also will be used only when dynamic_chunk_training is True.
|
||||||
num_left_chunks (int): the left context attention can see in chunks, the
|
num_left_chunks (int): the left context (in chunks) attention can see, the
|
||||||
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
||||||
A minus value means seeing full left context.
|
A minus value means seeing full left context.
|
||||||
This also will be used only when dynamic_chunk_training is True.
|
This also will be used only when dynamic_chunk_training is True.
|
||||||
@ -159,6 +95,8 @@ class Conformer(Transformer):
|
|||||||
vgg_frontend=vgg_frontend,
|
vgg_frontend=vgg_frontend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.encoder_layers = num_encoder_layers
|
||||||
|
self.d_model = d_model
|
||||||
self.dynamic_chunk_training = dynamic_chunk_training
|
self.dynamic_chunk_training = dynamic_chunk_training
|
||||||
self.short_chunk_threshold = short_chunk_threshold
|
self.short_chunk_threshold = short_chunk_threshold
|
||||||
self.short_chunk_size = short_chunk_size
|
self.short_chunk_size = short_chunk_size
|
||||||
@ -231,7 +169,7 @@ class Conformer(Transformer):
|
|||||||
num_left_chunks=self.num_left_chunks, device=x.device
|
num_left_chunks=self.num_left_chunks, device=x.device
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.encoder(
|
x, _ = self.encoder(
|
||||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||||
) # (T, N, C)
|
) # (T, N, C)
|
||||||
|
|
||||||
@ -248,11 +186,11 @@ class Conformer(Transformer):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
decode_states: Optional[DecodeStates] = None,
|
states: Optional[torch.Tensor] = None,
|
||||||
chunk_size: int = 16,
|
chunk_size: int = 16,
|
||||||
left_context: int = 64,
|
left_context: int = 64,
|
||||||
simulate_streaming: bool = False,
|
simulate_streaming: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
@ -260,9 +198,10 @@ class Conformer(Transformer):
|
|||||||
x_lens:
|
x_lens:
|
||||||
A tensor of shape (batch_size,) containing the number of frames in
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
`x` before padding.
|
`x` before padding.
|
||||||
decode_states:
|
states:
|
||||||
The decode states for previous frames which contains the cached data
|
The decode states for previous frames which contains the cached data.
|
||||||
and the offset of current chunk in the whole sequence.
|
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
|
||||||
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
chunk_size:
|
chunk_size:
|
||||||
The chunk size for decoding, this will be used to simulate streaming
|
The chunk size for decoding, this will be used to simulate streaming
|
||||||
decoding using masking.
|
decoding using masking.
|
||||||
@ -289,13 +228,15 @@ class Conformer(Transformer):
|
|||||||
|
|
||||||
if not simulate_streaming:
|
if not simulate_streaming:
|
||||||
assert (
|
assert (
|
||||||
decode_states is not None
|
states is not None
|
||||||
), "Require cache when sending data in streaming mode"
|
), "Require cache when sending data in streaming mode"
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
left_context == decode_states.left_context
|
states.shape == (2, self.encoder_layers, left_context, x.size(0), self.d_model)
|
||||||
), f"""The given left_context must equal to the left_context in
|
), f"""The shape of states MUST be equal to
|
||||||
`decode_states`, need {decode_states.left_context} given
|
(2, encoder_layers, left_context, batch, d_model) which is
|
||||||
{left_context}."""
|
{(2, self.encoder_layers, left_context, x.size(0), self.d_model)}
|
||||||
|
given {states.shape}."""
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths + left_context)
|
src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||||
|
|
||||||
@ -303,18 +244,16 @@ class Conformer(Transformer):
|
|||||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
x = self.encoder(
|
x, states = self.encoder(
|
||||||
embed,
|
embed,
|
||||||
pos_enc,
|
pos_enc,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
attn_cache=decode_states.attn_cache,
|
states=states,
|
||||||
conv_cache=decode_states.conv_cache,
|
left_context=left_context,
|
||||||
left_context=decode_states.left_context,
|
|
||||||
) # (T, B, F)
|
) # (T, B, F)
|
||||||
|
|
||||||
decode_states.offset += embed.size(0)
|
|
||||||
else:
|
else:
|
||||||
assert decode_states is None
|
assert states is None
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
@ -322,8 +261,11 @@ class Conformer(Transformer):
|
|||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
assert left_context % chunk_size == 0
|
|
||||||
num_left_chunks = left_context // chunk_size
|
num_left_chunks = -1
|
||||||
|
if left_context >= 0:
|
||||||
|
assert left_context % chunk_size == 0
|
||||||
|
num_left_chunks = left_context // chunk_size
|
||||||
|
|
||||||
mask = ~subsequent_chunk_mask(
|
mask = ~subsequent_chunk_mask(
|
||||||
size=x.size(0),
|
size=x.size(0),
|
||||||
@ -331,7 +273,7 @@ class Conformer(Transformer):
|
|||||||
num_left_chunks=num_left_chunks,
|
num_left_chunks=num_left_chunks,
|
||||||
device=x.device
|
device=x.device
|
||||||
)
|
)
|
||||||
x = self.encoder(
|
x, _ = self.encoder(
|
||||||
x,
|
x,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
@ -344,7 +286,7 @@ class Conformer(Transformer):
|
|||||||
logits = self.encoder_output_layer(x)
|
logits = self.encoder_output_layer(x)
|
||||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
return logits, lengths, decode_states
|
return logits, lengths, states
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoderLayer(nn.Module):
|
class ConformerEncoderLayer(nn.Module):
|
||||||
@ -425,10 +367,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
src_mask: Optional[Tensor] = None,
|
src_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
attn_cache: Optional[Tensor] = None,
|
states: Optional[Tensor] = None,
|
||||||
conv_cache: Optional[Tensor] = None,
|
|
||||||
left_context: int = 0,
|
left_context: int = 0,
|
||||||
) -> Tensor:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Pass the input through the encoder layer.
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
@ -437,9 +378,10 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
pos_emb: Positional embedding tensor (required).
|
pos_emb: Positional embedding tensor (required).
|
||||||
src_mask: the mask for the src sequence (optional).
|
src_mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
attn_cache: attention cache for previous frames.
|
states: The decode states for previous frames which contains the cached data.
|
||||||
conv_cache: convolution cache for previous frames.
|
It has a shape of (2, left_context, batch, attention_dim),
|
||||||
left_context: left context in frames used during streaming decoding.
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
Shape:
|
Shape:
|
||||||
@ -467,11 +409,11 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
key = src
|
key = src
|
||||||
val = src
|
val = src
|
||||||
if not self.training and attn_cache is not None:
|
if not self.training and states is not None:
|
||||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
key = torch.cat([attn_cache, src], dim=0)
|
key = torch.cat([states[0, ...], src], dim=0)
|
||||||
val = key
|
val = key
|
||||||
attn_cache = key
|
states[0, ...] = key[-left_context:, ...]
|
||||||
else:
|
else:
|
||||||
assert left_context == 0
|
assert left_context == 0
|
||||||
|
|
||||||
@ -493,9 +435,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_conv(src)
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
if not self.training and conv_cache is not None:
|
if not self.training and states is not None:
|
||||||
src = torch.cat([conv_cache, src], dim=0)
|
src = torch.cat([states[1, ...], src], dim=0)
|
||||||
conv_cache = src
|
states[1, ...] = src[-left_context:, ...]
|
||||||
|
|
||||||
src = self.conv_module(src)
|
src = self.conv_module(src)
|
||||||
src = src[-residual.size(0) :, :, :] # noqa: E203
|
src = src[-residual.size(0) :, :, :] # noqa: E203
|
||||||
@ -515,7 +457,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
src = self.norm_final(src)
|
src = self.norm_final(src)
|
||||||
|
|
||||||
return src, attn_cache, conv_cache
|
return src, states
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoder(nn.Module):
|
class ConformerEncoder(nn.Module):
|
||||||
@ -546,10 +488,9 @@ class ConformerEncoder(nn.Module):
|
|||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
attn_cache: Optional[Tensor] = None,
|
states: Optional[Tensor] = None,
|
||||||
conv_cache: Optional[Tensor] = None,
|
|
||||||
left_context: int = 0,
|
left_context: int = 0,
|
||||||
) -> Tensor:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -557,9 +498,10 @@ class ConformerEncoder(nn.Module):
|
|||||||
pos_emb: Positional embedding tensor (required).
|
pos_emb: Positional embedding tensor (required).
|
||||||
mask: the mask for the src sequence (optional).
|
mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
attn_cache: attention cache for previous frames.
|
states: The decode states for previous frames which contains the cached data.
|
||||||
conv_cache: convolution cache for previous frames.
|
It has a shape of (2, encoder_layers, left_context, batch, attention_dim),
|
||||||
left_context: left context in frames used during streaming decoding.
|
states[0,...] is the attn_cache, states[1,...] is the conv_cache.
|
||||||
|
left_context: left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
Shape:
|
Shape:
|
||||||
@ -576,26 +518,23 @@ class ConformerEncoder(nn.Module):
|
|||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
assert left_context == 0
|
assert left_context == 0
|
||||||
assert attn_cache is None
|
assert states is None
|
||||||
assert conv_cache is None
|
|
||||||
else:
|
else:
|
||||||
assert left_context >= 0
|
assert left_context >= 0
|
||||||
|
|
||||||
for layer_index, mod in enumerate(self.layers):
|
for layer_index, mod in enumerate(self.layers):
|
||||||
output, a_cache, c_cache = mod(
|
output, cache = mod(
|
||||||
output,
|
output,
|
||||||
pos_emb,
|
pos_emb,
|
||||||
src_mask=mask,
|
src_mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
attn_cache=None if attn_cache is None else attn_cache[layer_index],
|
states=None if states is None else states[:,layer_index, ...],
|
||||||
conv_cache=None if conv_cache is None else conv_cache[layer_index],
|
|
||||||
left_context=left_context,
|
left_context=left_context,
|
||||||
)
|
)
|
||||||
if attn_cache is not None and conv_cache is not None:
|
if states is not None:
|
||||||
attn_cache[layer_index, ...] = a_cache[-left_context:, ...]
|
states[:, layer_index, ...] = cache
|
||||||
conv_cache[layer_index, ...] = c_cache[-left_context:, ...]
|
|
||||||
|
|
||||||
return output
|
return output, states
|
||||||
|
|
||||||
|
|
||||||
class RelPositionalEncoding(torch.nn.Module):
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
@ -667,7 +606,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
context (int): left context in frames used during streaming decoding.
|
context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
Returns:
|
Returns:
|
||||||
@ -762,7 +701,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
need_weights: output attn_output_weights.
|
need_weights: output attn_output_weights.
|
||||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
left_context (int): left context in frames used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
|
|
||||||
@ -819,7 +758,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||||
time1 means the length of query vector.
|
time1 means the length of query vector.
|
||||||
left_context (int): left context in frames used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
|
|
||||||
@ -879,7 +818,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
need_weights: output attn_output_weights.
|
need_weights: output attn_output_weights.
|
||||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||||
left_context (int): left context in frames used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
|
|
||||||
|
@ -535,8 +535,11 @@ class MetricsTracker(collections.defaultdict):
|
|||||||
ans = []
|
ans = []
|
||||||
for k, v in self.items():
|
for k, v in self.items():
|
||||||
if k != "frames":
|
if k != "frames":
|
||||||
norm_value = float(v) / num_frames
|
if k != "sym_delay":
|
||||||
ans.append((k, norm_value))
|
norm_value = float(v) / num_frames
|
||||||
|
ans.append((k, norm_value))
|
||||||
|
else:
|
||||||
|
ans.append((k, float(v)))
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def reduce(self, device):
|
def reduce(self, device):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user