mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
support streaming in conformer
This commit is contained in:
parent
6f7860a0a6
commit
5bd2490b44
@ -249,6 +249,25 @@ def get_parser():
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--streaming-mode",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--right-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="right context to attend during decoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context to attend during decoding",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -301,6 +320,15 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.streaming_mode:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
streaming_data=False
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
@ -526,6 +554,10 @@ def main():
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.streaming_mode:
|
||||
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-use-LG-{params.use_LG}"
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
@ -561,6 +593,10 @@ def main():
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
# TODO(wei kang): make following config more elegant
|
||||
params.dynamic_chunk_training=params.streaming_mode
|
||||
params.short_chunk_size=25
|
||||
params.num_left_chunks=params.left_context // params.right_chunk_size
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.iter > 0:
|
||||
|
@ -222,6 +222,29 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="chunk length of dynamic training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="chunk length of dynamic training",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -310,6 +333,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
num_left_chunks=params.num_left_chunks,
|
||||
causal=True if params.dynamic_chunk_training else False,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -18,13 +18,77 @@
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Transformer
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||
|
||||
|
||||
class DecodeStates(object):
|
||||
def __init__(self,
|
||||
layers: int,
|
||||
left_context: int,
|
||||
dim: int,
|
||||
init: bool = True,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
self.layers = layers
|
||||
self.left_context = left_context
|
||||
self.dim = dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if init:
|
||||
# shape (layer, T, dim)
|
||||
self.attn_cache = torch.zeros((layers, left_context, dim),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
self.conv_cache = torch.zeros((layers, left_context, dim),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
self.offset = torch.tensor([0], dtype=dtype, device=device)
|
||||
|
||||
@staticmethod
|
||||
def stack(states: List['DecodeStates']) -> 'DecodeStates':
|
||||
assert len(states) >= 1
|
||||
obj = DecodeStates(layers=states[0].layers,
|
||||
left_context=states[0].left_context,
|
||||
dim=states[0].dim,
|
||||
init=False,
|
||||
dtype=states[0].dtype,
|
||||
device=states[0].device)
|
||||
attn_cache = []
|
||||
conv_cache = []
|
||||
offset = []
|
||||
for i in range(len(states)):
|
||||
attn_cache.append(states[i].attn_cache)
|
||||
conv_cache.append(states[i].conv_cache)
|
||||
offset.append(states[i].offset)
|
||||
obj.attn_cache = torch.stack(attn_cache, dim=2)
|
||||
obj.conv_cache = torch.stack(conv_cache, dim=2)
|
||||
obj.offset = torch.stack(offset, dim=0)
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def unstack(states: 'DecodeStates') -> List['DecodeStates']:
|
||||
results = []
|
||||
attn_cache = torch.unbind(states.attn_cache, dim=2)
|
||||
conv_cache = torch.unbind(states.conv_cache, dim=2)
|
||||
offset = torch.unbind(states.offset, dim=0)
|
||||
for i in range(states.attn_cache.size(2)):
|
||||
obj = DecodeStates(layers=states.layers,
|
||||
left_context=states.left_context,
|
||||
dim=states.dim,
|
||||
init=False,
|
||||
dtype=states.dtype,
|
||||
device=states.device)
|
||||
obj.attn_cache = attn_cache[i]
|
||||
obj.conv_cache = conv_cache[i]
|
||||
obj.offset = offset[i]
|
||||
results.append(obj)
|
||||
return results
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
@ -56,6 +120,11 @@ class Conformer(Transformer):
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
dynamic_chunk_training: bool = False,
|
||||
short_chunk_threshold: float = 0.75,
|
||||
short_chunk_size: int = 25,
|
||||
num_left_chunks: int = -1,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
@ -70,6 +139,12 @@ class Conformer(Transformer):
|
||||
vgg_frontend=vgg_frontend,
|
||||
)
|
||||
|
||||
self.dynamic_chunk_training = dynamic_chunk_training
|
||||
self.short_chunk_threshold = short_chunk_threshold
|
||||
self.short_chunk_size = short_chunk_size
|
||||
self.num_left_chunks = num_left_chunks
|
||||
self.causal = causal
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
@ -79,6 +154,7 @@ class Conformer(Transformer):
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
normalize_before,
|
||||
causal,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
self.normalize_before = normalize_before
|
||||
@ -115,9 +191,29 @@ class Conformer(Transformer):
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
mask = None
|
||||
|
||||
if self.dynamic_chunk_training:
|
||||
assert (
|
||||
self.causal
|
||||
), "Causal convolution is required for streaming conformer."
|
||||
max_len = x.size(0)
|
||||
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0), chunk_size=chunk_size,
|
||||
num_left_chunks=self.num_left_chunks, device=x.device
|
||||
)
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||
) # (T, N, C)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
@ -128,6 +224,80 @@ class Conformer(Transformer):
|
||||
return logits, lengths
|
||||
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
decode_states: Optional[DecodeStates] = None,
|
||||
chunk_size: int = 32,
|
||||
left_context: int = 64,
|
||||
streaming_data: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
|
||||
# x: [N, T, C]
|
||||
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
|
||||
if streaming_data:
|
||||
assert (
|
||||
decode_states is not None
|
||||
), "Require cache when sending data in streaming mode"
|
||||
assert (
|
||||
left_context == decode_states.left_context
|
||||
), f"""The given left_context must equal to the left_context in
|
||||
`decode_states`, need {decode_states.left_context} given
|
||||
{left_context}."""
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||
|
||||
embed = self.encoder_embed(x)
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
|
||||
x = self.encoder(
|
||||
embed,
|
||||
pos_enc,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
attn_cache=decode_states.attn_cache,
|
||||
conv_cache=decode_states.conv_cache,
|
||||
left_context=decode_states.left_context,
|
||||
) # (T, B, F)
|
||||
|
||||
decode_states.offset += embed.size(0)
|
||||
else:
|
||||
assert decode_states is None
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
assert left_context % chunk_size == 0
|
||||
num_left_chunks = left_context // chunk_size
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=num_left_chunks,
|
||||
device=x.device
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
) # (T, N, C)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
|
||||
logits = self.encoder_output_layer(x)
|
||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return logits, lengths, decode_states
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||
@ -156,6 +326,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
@ -176,7 +347,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
self.conv_module = ConvolutionModule(
|
||||
d_model, cnn_module_kernel, causal=causal
|
||||
)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
@ -201,6 +374,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
attn_cache: Optional[Tensor] = None,
|
||||
conv_cache: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
@ -233,13 +409,25 @@ class ConformerEncoderLayer(nn.Module):
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
key = src
|
||||
val = src
|
||||
if not self.training and attn_cache is not None:
|
||||
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||
key = torch.cat([attn_cache, src], dim=0)
|
||||
val = key
|
||||
attn_cache = key
|
||||
else:
|
||||
assert left_context == 0
|
||||
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
key,
|
||||
val,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
left_context=left_context,
|
||||
)[0]
|
||||
src = residual + self.dropout(src_att)
|
||||
if not self.normalize_before:
|
||||
@ -249,7 +437,15 @@ class ConformerEncoderLayer(nn.Module):
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
src = residual + self.dropout(self.conv_module(src))
|
||||
|
||||
if not self.training and conv_cache is not None:
|
||||
src = torch.cat([conv_cache, src], dim=0)
|
||||
conv_cache = src
|
||||
|
||||
src = self.conv_module(src)
|
||||
src = src[-residual.size(0) :, :, :] # noqa: E203
|
||||
|
||||
src = residual + self.dropout(src)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
@ -264,7 +460,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
if self.normalize_before:
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src
|
||||
return src, attn_cache, conv_cache
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
@ -295,6 +491,9 @@ class ConformerEncoder(nn.Module):
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
attn_cache: Optional[Tensor] = None,
|
||||
conv_cache: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -314,13 +513,26 @@ class ConformerEncoder(nn.Module):
|
||||
"""
|
||||
output = src
|
||||
|
||||
for mod in self.layers:
|
||||
output = mod(
|
||||
if self.training:
|
||||
assert left_context == 0
|
||||
assert attn_cache is None
|
||||
assert conv_cache is None
|
||||
else:
|
||||
assert left_context >= 0
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
output, a_cache, c_cache = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
attn_cache=None if attn_cache is None else attn_cache[layer_index],
|
||||
conv_cache=None if conv_cache is None else conv_cache[layer_index],
|
||||
left_context=left_context,
|
||||
)
|
||||
if attn_cache is not None and conv_cache is not None:
|
||||
attn_cache[layer_index, ...] = a_cache[-left_context:, ...]
|
||||
conv_cache[layer_index, ...] = c_cache[-left_context:, ...]
|
||||
|
||||
return output
|
||||
|
||||
@ -349,12 +561,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
def extend_pe(self, x: Tensor, context: int = 0) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
x_size_1 = x.size(1) + context
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
@ -364,9 +577,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
@ -384,7 +597,11 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
@ -395,14 +612,15 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
self.extend_pe(x, context)
|
||||
x = x * self.xscale
|
||||
x_size_1 = x.size(1) + context
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
- x_size_1
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
+ x_size_1,
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
@ -467,6 +685,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -525,9 +744,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
left_context=left_context,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
@ -540,14 +760,17 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
time2 = time1 + left_context
|
||||
|
||||
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
||||
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
@ -569,6 +792,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -748,7 +972,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||
1, 2
|
||||
@ -768,9 +994,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
q_with_bias_v, p
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
@ -805,6 +1032,24 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
|
||||
# If we are using dynamic_chunk_training and setting a limited
|
||||
# num_left_chunks, the attention may only see the padding values which
|
||||
# will also be masked out by `key_padding_mask`, at this circumstances,
|
||||
# the whole column of `attn_output_weights` will be `-inf`
|
||||
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
||||
# positions to avoid invalid loss value below.
|
||||
if attn_mask is not None and attn_mask.dtype == torch.bool and \
|
||||
key_padding_mask is not None:
|
||||
combined_mask = attn_mask.unsqueeze(
|
||||
0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
combined_mask, 0.0)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len)
|
||||
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
@ -842,12 +1087,17 @@ class ConvolutionModule(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
causal: bool = False
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
self.causal = causal
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
@ -857,12 +1107,18 @@ class ConvolutionModule(nn.Module):
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
self.lorder = kernel_size - 1
|
||||
padding = (kernel_size - 1) // 2
|
||||
if self.causal:
|
||||
padding = 0
|
||||
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
@ -895,6 +1151,11 @@ class ConvolutionModule(nn.Module):
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if self.causal and self.lorder > 0:
|
||||
# Make depthwise_conv causal by
|
||||
# manualy padding self.lorder zeros to the left
|
||||
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||
|
||||
x = self.depthwise_conv(x)
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
@ -61,5 +61,6 @@ from .utils import (
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
subsequent_chunk_mask,
|
||||
write_error_stats,
|
||||
)
|
||||
|
@ -693,6 +693,42 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
return expaned_lengths >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
|
||||
def subsequent_chunk_mask(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||
ret[i, start:ending] = True
|
||||
return ret
|
||||
|
||||
|
||||
def l1_norm(x):
|
||||
return torch.sum(torch.abs(x))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user